learning_rate_scheduler.py 21.0 KB
Newer Older
Q
Qiao Longfei 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Y
yuyang18 已提交
14 15 16 17 18 19 20 21
"""
When training a model, it's often useful to decay the
learning rate during the training process; this is called
learning rate decay. This module provides several classical
learning rate decay strategies, and users can implement their
own strategies following the patterns used in this module.
"""
Q
Qiao Longfei 已提交
22

23 24
from __future__ import print_function

25
import math
Q
qingqing01 已提交
26
import numbers
27

28 29 30 31
from . import control_flow
from . import nn
from . import ops
from . import tensor
32
from ..framework import default_main_program, Parameter, unique_name, name_scope
Q
qingqing01 已提交
33
from ..framework import Variable
J
Jiabin Yang 已提交
34
from ..framework import _non_static_mode
M
minqiyang 已提交
35
from ..dygraph import learning_rate_scheduler as imperate_lr
36
from ..data_feeder import check_variable_and_dtype, check_type
Q
Qiao Longfei 已提交
37

38 39
# Public learning rate decay functions exported by this module.
__all__ = [
    'exponential_decay',
    'natural_exp_decay',
    'inverse_time_decay',
    'polynomial_decay',
    'piecewise_decay',
    'noam_decay',
    'cosine_decay',
    'linear_lr_warmup',
]
Q
Qiao Longfei 已提交
43 44


45
def _decay_step_counter(begin=0):
    """
    Return the learning-rate-decay global step counter as a float32 Variable.

    The underlying counter is named ``@LR_DECAY_COUNTER@``, starts at
    ``begin`` and is incremented with ``step=1``.

    Args:
        begin (int): Value of the first global step. Default 0, i.e. the
            first global step seen by learning rate decay is zero.

    Returns:
        Variable: The step counter cast to float32.
    """
    step_counter = nn.autoincreased_step_counter(
        counter_name='@LR_DECAY_COUNTER@',
        begin=begin,
        step=1,
    )
    return tensor.cast(step_counter, 'float32')


53
def noam_decay(d_model, warmup_steps, learning_rate=1.0):
    """
    Noam learning rate decay, as described in `Attention Is All You Need
    <https://arxiv.org/pdf/1706.03762.pdf>`_.

    The rate grows linearly for the first ``warmup_steps`` steps and then
    decays proportionally to the inverse square root of the step, because it
    is the minimum of ``step ** -0.5`` and ``step * warmup_steps ** -1.5``.
    A numpy reference implementation:

    .. code-block:: python

      import numpy as np
      # set hyper parameters
      base_lr = 0.01
      d_model = 2
      current_steps = 20
      warmup_steps = 200
      # compute
      lr_value = base_lr * np.power(d_model, -0.5) * np.min([
          np.power(current_steps, -0.5),
          np.power(warmup_steps, -1.5) * current_steps])

    In static-graph mode the step counter starts at 1, so ``step ** -0.5``
    is always defined.

    Args:
        d_model (Variable|int|float): The dimensionality of the model's input
            and output. The learning rate is scaled by ``d_model ** -0.5``.
        warmup_steps (Variable|int|float): Number of warm-up steps
            (a hyperparameter).
        learning_rate (Variable|float|int): The initial learning rate. If it
            is a Variable, it is a tensor with shape [1] whose data type is
            float32 or float64. Default 1.0.

    Returns:
        In static-graph mode, a Variable holding the decayed learning rate.
        In dynamic-graph mode, a ``NoamDecay`` scheduler from
        ``fluid.dygraph.learning_rate_scheduler``.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          warmup_steps = 100
          learning_rate = 0.01
          lr = fluid.layers.learning_rate_scheduler.noam_decay(
                         1/(warmup_steps *(learning_rate ** 2)),
                         warmup_steps,
                         learning_rate)
    """
    with default_main_program()._lr_schedule_guard():
        if not _non_static_mode():
            step = _decay_step_counter(1)

            inv_sqrt_step = step**-0.5
            warmup_term = (warmup_steps**-1.5) * step
            scale = learning_rate * (d_model**-0.5)
            return scale * nn.elementwise_min(inv_sqrt_step, warmup_term)

        return imperate_lr.NoamDecay(d_model,
                                     warmup_steps,
                                     learning_rate=learning_rate)
112 113


Y
Yu Yang 已提交
114
def exponential_decay(learning_rate, decay_steps, decay_rate, staircase=False):
    """
    Applies exponential decay to the learning rate.

    When training a model, it is often recommended to lower the learning rate
    as training progresses. With this function the learning rate is multiplied
    by ``decay_rate`` once every ``decay_steps`` steps.

    The decayed learning rate is computed as follows:

    >>> if staircase == True:
    >>>     decayed_learning_rate = learning_rate * decay_rate ^ floor(global_step / decay_steps)
    >>> else:
    >>>     decayed_learning_rate = learning_rate * decay_rate ^ (global_step / decay_steps)

    Args:
        learning_rate (Variable|float): The initial learning rate, either a
            Variable or a Python float.
        decay_steps (int): The learning rate decay steps. See the formula above.
        decay_rate (float): The learning rate decay rate. See the formula above.
        staircase (bool): If True, decay the learning rate at discrete
            intervals, i.e. multiply it by ``decay_rate`` every
            ``decay_steps`` steps. If False, decay it continuously following
            the formula above. Default: False.

    Returns:
        In static-graph mode, a Variable holding the decayed learning rate
        (float32 when ``learning_rate`` is a Python float). In dynamic-graph
        mode, an ``ExponentialDecay`` scheduler object.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          import paddle

          paddle.enable_static()
          base_lr = 0.1
          sgd_optimizer = fluid.optimizer.SGD(
              learning_rate=fluid.layers.exponential_decay(
                  learning_rate=base_lr,
                  decay_steps=10000,
                  decay_rate=0.5,
                  staircase=True))

    """
    with default_main_program()._lr_schedule_guard():
        if not _non_static_mode():
            step = _decay_step_counter()

            exponent = step / decay_steps
            if staircase:
                # Only completed multiples of decay_steps count.
                exponent = ops.floor(exponent)
            return learning_rate * (decay_rate**exponent)

        return imperate_lr.ExponentialDecay(learning_rate, decay_steps,
                                            decay_rate, staircase)
Q
Qiao Longfei 已提交
173 174


Y
Yu Yang 已提交
175
def natural_exp_decay(learning_rate, decay_steps, decay_rate, staircase=False):
    """
    Applies natural exponential decay to the initial learning rate.

    When training a model, it is often recommended to lower the learning rate
    as training progresses. With this function the learning rate decays by a
    natural exponential power of ``decay_rate`` every ``decay_steps`` steps.

    The decayed learning rate is computed as follows:

    >>> if not staircase:
    >>>     decayed_learning_rate = learning_rate * exp(- decay_rate * (global_step / decay_steps))
    >>> else:
    >>>     decayed_learning_rate = learning_rate * exp(- decay_rate * floor(global_step / decay_steps))

    Args:
        learning_rate (Variable|float): The initial learning rate, either a
            Variable or a Python float.
        decay_steps (int): The learning rate decay steps. See the formula above.
        decay_rate (float): The learning rate decay rate. See the formula above.
        staircase (bool): If True, decay the learning rate at discrete
            intervals, i.e. by a natural exponential power of ``decay_rate``
            every ``decay_steps`` steps. If False, decay it continuously
            following the formula above. Default: False.

    Returns:
        In static-graph mode, a Variable holding the decayed learning rate
        (float32 when ``learning_rate`` is a Python float). In dynamic-graph
        mode, a ``NaturalExpDecay`` scheduler object.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          import paddle

          paddle.enable_static()
          base_lr = 0.1
          sgd_optimizer = fluid.optimizer.SGD(
              learning_rate=fluid.layers.natural_exp_decay(
                  learning_rate=base_lr,
                  decay_steps=10000,
                  decay_rate=0.5,
                  staircase=True))

    """
    with default_main_program()._lr_schedule_guard():
        if not _non_static_mode():
            step = _decay_step_counter()

            progress = step / decay_steps
            if staircase:
                # Only completed multiples of decay_steps count.
                progress = ops.floor(progress)
            exponent = -1 * decay_rate * progress
            return learning_rate * ops.exp(exponent)

        return imperate_lr.NaturalExpDecay(learning_rate, decay_steps,
                                           decay_rate, staircase)
Q
Qiao Longfei 已提交
234 235


Y
Yu Yang 已提交
236
def inverse_time_decay(learning_rate, decay_steps, decay_rate, staircase=False):
    """
    Applies inverse time decay to the initial learning rate.

    When training a model, it is often recommended to lower the learning rate
    as training progresses. With this function an inverse decay function is
    applied to the initial learning rate.

    The decayed learning rate is computed as follows:

    >>> if staircase == True:
    >>>     decayed_learning_rate = learning_rate / (1 + decay_rate * floor(global_step / decay_steps))
    >>> else:
    >>>     decayed_learning_rate = learning_rate / (1 + decay_rate * global_step / decay_steps)

    Args:
        learning_rate (Variable|float): The initial learning rate, either a
            Variable or a Python float.
        decay_steps (int): The learning rate decay steps. See the formula above.
        decay_rate (float): The learning rate decay rate. See the formula above.
        staircase (bool): If True, decay the learning rate at discrete
            intervals, i.e. only once every ``decay_steps`` steps. If False,
            decay it continuously following the formula above.
            Default: False.

    Returns:
        In static-graph mode, a Variable holding the decayed learning rate
        (float32 when ``learning_rate`` is a Python float). In dynamic-graph
        mode, an ``InverseTimeDecay`` scheduler object.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          import paddle
          paddle.enable_static()
          base_lr = 0.1
          sgd_optimizer = fluid.optimizer.SGD(
              learning_rate=fluid.layers.inverse_time_decay(
                  learning_rate=base_lr,
                  decay_steps=10000,
                  decay_rate=0.5,
                  staircase=True))
    """
    with default_main_program()._lr_schedule_guard():
        if _non_static_mode():
            decay = imperate_lr.InverseTimeDecay(learning_rate, decay_steps,
                                                 decay_rate, staircase)
            return decay
        else:
            global_step = _decay_step_counter()

            div_res = global_step / decay_steps
            if staircase:
                # Only completed multiples of decay_steps count.
                div_res = ops.floor(div_res)

            decayed_lr = learning_rate / (1 + decay_rate * div_res)

            return decayed_lr
294 295 296 297 298 299 300


def polynomial_decay(learning_rate,
                     decay_steps,
                     end_learning_rate=0.0001,
                     power=1.0,
                     cycle=False):
    """
    Applies polynomial decay to the initial learning rate.

    .. code-block:: text

     if cycle:
       decay_steps = decay_steps * ceil(global_step / decay_steps)
     else:
       global_step = min(global_step, decay_steps)
     decayed_learning_rate = (learning_rate - end_learning_rate) *
          (1 - global_step / decay_steps) ^ power + end_learning_rate

    When ``cycle`` is True and ``global_step`` is 0, the ``ceil`` multiplier
    is forced to 1 so that ``decay_steps`` never becomes zero.

    Args:
        learning_rate (Variable|float): The initial learning rate during
            training, either a scalar float or a Variable.
        decay_steps (int): Number of steps over which the learning rate
            decays from ``learning_rate`` to ``end_learning_rate``.
        end_learning_rate (float): The learning rate reached after
            ``decay_steps`` steps. Default 0.0001.
        power (float): Exponent of the polynomial. Default 1.0 (linear decay).
        cycle (bool): If True, restart the decay every ``decay_steps`` steps
            (see the formula above). Default False.

    Returns:
        In static-graph mode, a Variable holding the decayed learning rate.
        In dynamic-graph mode, a ``PolynomialDecay`` scheduler object.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          start_lr = 0.01
          total_step = 5000
          end_lr = 0
          lr = fluid.layers.polynomial_decay(
              start_lr, total_step, end_lr, power=1)

    """
    with default_main_program()._lr_schedule_guard():
        if _non_static_mode():
            decay = imperate_lr.PolynomialDecay(learning_rate, decay_steps,
                                                end_learning_rate, power, cycle)
            return decay
        else:
            global_step = _decay_step_counter()

            if cycle:
                div_res = ops.ceil(global_step / decay_steps)
                zero_var = tensor.fill_constant(shape=[1],
                                                dtype='float32',
                                                value=0.0)
                one_var = tensor.fill_constant(shape=[1],
                                               dtype='float32',
                                               value=1.0)

                # ceil(0 / decay_steps) == 0 would make decay_steps zero and
                # the division below undefined; use a multiplier of 1 instead.
                with control_flow.Switch() as switch:
                    with switch.case(global_step == zero_var):
                        tensor.assign(input=one_var, output=div_res)
                decay_steps = decay_steps * div_res
            else:
                # Clamp the step so the rate stays at end_learning_rate once
                # decay_steps have elapsed.
                decay_steps_var = tensor.fill_constant(shape=[1],
                                                       dtype='float32',
                                                       value=float(decay_steps))
                global_step = nn.elementwise_min(x=global_step,
                                                 y=decay_steps_var)

            decayed_lr = (learning_rate - end_learning_rate) * \
                ((1 - global_step / decay_steps) ** power) + end_learning_rate
            return decayed_lr
366 367


Y
Yu Yang 已提交
368
def piecewise_decay(boundaries, values):
    """
    Applies piecewise decay to the initial learning rate.

    The algorithm can be described as the code below.

    .. code-block:: text

      boundaries = [10000, 20000]
      values = [1.0, 0.5, 0.1]
      if step < 10000:
          learning_rate = 1.0
      elif 10000 <= step < 20000:
          learning_rate = 0.5
      else:
          learning_rate = 0.1

    Args:
        boundaries (list): Step numbers at which the learning rate changes.
        values (list): Learning rate values picked for the intervals between
            boundaries. Must contain exactly ``len(boundaries) + 1`` items.

    Returns:
        In static-graph mode, a persistable float32 Variable (shape [1]) named
        ``learning_rate`` holding the current learning rate. In dynamic-graph
        mode, a ``PiecewiseDecay`` scheduler object.

    Raises:
        ValueError: If ``len(values) - len(boundaries) != 1``.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          import paddle
          paddle.enable_static()
          boundaries = [10000, 20000]
          values = [1.0, 0.5, 0.1]
          optimizer = fluid.optimizer.Momentum(
              momentum=0.9,
              learning_rate=fluid.layers.piecewise_decay(boundaries=boundaries, values=values),
              regularization=fluid.regularizer.L2Decay(1e-4))

    """
    with default_main_program()._lr_schedule_guard():
        if len(values) - len(boundaries) != 1:
            raise ValueError("len(values) - len(boundaries) should be 1")

        if _non_static_mode():
            decay = imperate_lr.PiecewiseDecay(boundaries, values, 0)
            return decay
        else:
            global_step = _decay_step_counter()

            lr = tensor.create_global_var(shape=[1],
                                          value=0.0,
                                          dtype='float32',
                                          persistable=True,
                                          name="learning_rate")

            with control_flow.Switch() as switch:
                # zip stops after the last boundary; the extra trailing value
                # is handled by the default case below.
                for boundary, value in zip(boundaries, values):
                    boundary_val = tensor.fill_constant(shape=[1],
                                                        dtype='float32',
                                                        value=float(boundary),
                                                        force_cpu=True)
                    with switch.case(global_step < boundary_val):
                        tensor.fill_constant(shape=[1],
                                             dtype="float32",
                                             value=float(value),
                                             out=lr)
                with switch.default():
                    tensor.fill_constant(shape=[1],
                                         dtype="float32",
                                         value=float(values[-1]),
                                         out=lr)

            return lr
W
Wu Yi 已提交
443 444


S
shippingwang 已提交
445
def cosine_decay(learning_rate, step_each_epoch, epochs):
    r"""
    Applies cosine decay to the learning rate.

    When training a model, it is often recommended to lower the learning rate
    as training progresses. With this function the learning rate is decayed
    following a cosine schedule:

    .. math::

        decayed\_lr = learning\_rate * 0.5 * (\cos(epoch * \frac{\pi}{epochs}) + 1)

    where :math:`epoch = \lfloor global\_step / step\_each\_epoch \rfloor`.

    Args:
        learning_rate (Variable|float): The initial learning rate.
        step_each_epoch (int): The number of steps in an epoch.
        epochs (int): The number of epochs.

    Returns:
        In static-graph mode, a Variable holding the decayed learning rate.
        In dynamic-graph mode, a ``CosineDecay`` scheduler object.

    Raises:
        TypeError: If ``learning_rate`` is neither a float nor a Variable
            (raised by ``check_type``).

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid
            base_lr = 0.1
            lr = fluid.layers.cosine_decay(
                learning_rate=base_lr, step_each_epoch=10000, epochs=120)
    """
    check_type(learning_rate, 'learning_rate', (float, tensor.Variable),
               'cosine_decay')

    with default_main_program()._lr_schedule_guard():
        if _non_static_mode():
            decay = imperate_lr.CosineDecay(learning_rate, step_each_epoch,
                                            epochs)
            return decay
        else:
            global_step = _decay_step_counter()

            # The schedule advances once per completed epoch, not per step.
            cur_epoch = ops.floor(global_step / step_each_epoch)
            decayed_lr = learning_rate * 0.5 * (
                ops.cos(cur_epoch * math.pi / epochs) + 1)
            return decayed_lr
S
shippingwang 已提交
489 490


491 492
def linear_lr_warmup(learning_rate, warmup_steps, start_lr, end_lr):
    """
    This operator uses a linear learning rate warm-up strategy to adjust the
    learning rate before the normal learning rate schedule takes over.
    For more information, please refer to `Bag of Tricks for Image Classification with Convolutional Neural Networks <https://arxiv.org/abs/1812.01187>`_

    When global_step < warmup_steps, the learning rate is updated as:

    .. code-block:: text

            linear_step = end_lr - start_lr
            lr = start_lr + linear_step * (global_step / warmup_steps)

    where start_lr is the initial learning rate and end_lr is the final
    learning rate of the warm-up.

    When global_step >= warmup_steps, the learning rate is updated as:

    .. code-block:: text

            lr = learning_rate

    where lr is the learning rate after warm-up.

    Args:
        learning_rate (Variable|float): Learning rate after warm-up. It can be
            a 1D-Tensor or a single value with the data type of float32.
        warmup_steps (int): Number of warm-up steps.
        start_lr (float): Initial learning rate of the warm-up.
        end_lr (float): Final learning rate of the warm-up.

    Returns:
        Variable: Warm-up learning rate with the same data type as
        learning_rate (static-graph mode). In dynamic-graph mode, a
        ``LinearLrWarmup`` scheduler object.

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid

            boundaries = [100, 200]
            lr_steps = [0.1, 0.01, 0.001]
            learning_rate = fluid.layers.piecewise_decay(boundaries, lr_steps) #case1, 1D-Tensor
            #learning_rate = 0.1  #case2, single-value
            warmup_steps = 50
            start_lr = 1. / 3.
            end_lr = 0.1
            decayed_lr = fluid.layers.linear_lr_warmup(learning_rate,
                warmup_steps, start_lr, end_lr)

            place = fluid.CPUPlace()
            exe = fluid.Executor(place)
            exe.run(fluid.default_startup_program())
            out, = exe.run(fetch_list=[decayed_lr.name])
            print(out)
            # case1: [0.33333334]
            # case2: [0.33333334]
    """
    dtype = 'float32'
    if isinstance(learning_rate, Variable):
        dtype = learning_rate.dtype

    linear_step = float(end_lr) - float(start_lr)
    with default_main_program()._lr_schedule_guard():

        if _non_static_mode():
            lr = imperate_lr.LinearLrWarmup(learning_rate, warmup_steps,
                                            start_lr, end_lr)
            return lr
        else:
            lr = tensor.create_global_var(shape=[1],
                                          value=0.0,
                                          dtype=dtype,
                                          persistable=True,
                                          name="learning_rate_warmup")

            global_step = _decay_step_counter()

            with control_flow.Switch() as switch:
                with switch.case(global_step < warmup_steps):
                    decayed_lr = start_lr + linear_step * (global_step /
                                                           float(warmup_steps))
                    if isinstance(learning_rate, Variable):
                        # global_step is float32, so decayed_lr is float32;
                        # cast it so `lr` keeps learning_rate's dtype (e.g.
                        # float64) during warm-up as well as afterwards.
                        decayed_lr = tensor.cast(decayed_lr, dtype)
                    tensor.assign(decayed_lr, lr)
                with switch.default():
                    if not isinstance(learning_rate, Variable):
                        learning_rate = tensor.fill_constant(
                            shape=[1], dtype=dtype, value=float(learning_rate))
                    tensor.assign(learning_rate, lr)
            return lr