learning_rate_scheduler.py 16.9 KB
Newer Older
Q
Qiao Longfei 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Y
yuyang18 已提交
14 15 16 17 18 19 20 21
"""
When training a model, it is often useful to decay the learning rate as
training progresses; this is called learning rate decay. There are many
strategies for doing this, and this module provides several classical ones.
Users can also implement their own learning rate decay strategy following
the pattern used in this module.
"""
Q
Qiao Longfei 已提交
22

23 24
from __future__ import print_function

25
import math
Q
qingqing01 已提交
26
import numbers
27

28 29 30 31
from . import control_flow
from . import nn
from . import ops
from . import tensor
32
from ..initializer import init_on_cpu
33
from ..framework import default_main_program, Parameter, unique_name, name_scope
Q
qingqing01 已提交
34
from ..framework import Variable
M
minqiyang 已提交
35 36
from ..dygraph import base as imperative_base
from ..dygraph import learning_rate_scheduler as imperate_lr
Q
Qiao Longfei 已提交
37

38 39
# Public learning rate decay API exported by this module.
__all__ = [
    'exponential_decay', 'natural_exp_decay', 'inverse_time_decay',
    'polynomial_decay', 'piecewise_decay', 'noam_decay', 'cosine_decay',
    'linear_lr_warmup'
]
Q
Qiao Longfei 已提交
43 44


45
def _decay_step_counter(begin=0):
    """Return the shared learning-rate-decay step counter as float32.

    The counter is created through ``nn.autoincreased_step_counter`` under
    the fixed name '@LR_DECAY_COUNTER@' with ``step=1``, so every decay
    schedule in this module reads the same counter.

    Args:
        begin(int): Initial counter value. Defaults to 0, i.e. the first
            global step seen by learning rate decay is zero.

    Returns:
        Variable: The counter cast to 'float32'.
    """
    counter = nn.autoincreased_step_counter(
        counter_name='@LR_DECAY_COUNTER@', begin=begin, step=1)
    # Cast so the counter can be mixed with float hyper-parameters.
    return tensor.cast(counter, 'float32')


53
def noam_decay(d_model, warmup_steps):
    """
    Noam decay method. The numpy implementation of noam decay is as follows.

    .. code-block:: python

      import numpy as np
      # set hyper parameters
      d_model = 2
      current_steps = 20
      warmup_steps = 200
      # compute
      lr_value = np.power(d_model, -0.5) * np.min([
                              np.power(current_steps, -0.5),
                              np.power(warmup_steps, -1.5) * current_steps])

    While ``current_steps < warmup_steps`` the learning rate grows linearly
    with the step; afterwards it decays proportionally to the inverse square
    root of the step.

    Please reference `attention is all you need
    <https://arxiv.org/pdf/1706.03762.pdf>`_.

    Args:
        d_model(Variable|int|float): The dimensionality of input and output
            of the model.
        warmup_steps(Variable|int): The number of warmup steps.

    Returns:
        The decayed learning rate. In imperative (dygraph) mode this is an
        ``imperate_lr.NoamDecay`` object; otherwise a Variable.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          warmup_steps = 100
          learning_rate = 0.01
          lr = fluid.layers.learning_rate_scheduler.noam_decay(
                         1/(warmup_steps *(learning_rate ** 2)),
                         warmup_steps)
    """
    with default_main_program()._lr_schedule_guard():
        if imperative_base.enabled():
            decay = imperate_lr.NoamDecay(d_model, warmup_steps)
            return decay
        else:
            # Start counting at 1: global_step ** -0.5 is undefined at step 0.
            global_step = _decay_step_counter(1)

            a = global_step**-0.5
            b = (warmup_steps**-1.5) * global_step
            lr_value = (d_model**-0.5) * nn.elementwise_min(a, b)

            return lr_value
102 103


Y
Yu Yang 已提交
104
def exponential_decay(learning_rate, decay_steps, decay_rate, staircase=False):
    """
    Applies exponential decay to the learning rate.

    Lowering the learning rate as training progresses is a common
    recommendation. With this function the learning rate is multiplied by
    'decay_rate' once every 'decay_steps' steps:

    >>> if staircase == True:
    >>>     decayed_learning_rate = learning_rate * decay_rate ^ floor(global_step / decay_steps)
    >>> else:
    >>>     decayed_learning_rate = learning_rate * decay_rate ^ (global_step / decay_steps)

    Args:
        learning_rate(Variable|float): The initial learning rate.
        decay_steps(int): See the decay computation above.
        decay_rate(float): The decay rate. See the decay computation above.
        staircase(bool): If True, decay the learning rate at discrete
                         intervals. Default: False

    Returns:
        Variable: The decayed learning rate. In imperative (dygraph) mode an
        ``imperate_lr.ExponentialDecay`` object is returned instead.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          base_lr = 0.1
          sgd_optimizer = fluid.optimizer.SGD(
              learning_rate=fluid.layers.exponential_decay(
                  learning_rate=base_lr,
                  decay_steps=10000,
                  decay_rate=0.5,
                  staircase=True))

    """
    with default_main_program()._lr_schedule_guard():
        if not imperative_base.enabled():
            step = _decay_step_counter()
            ratio = step / decay_steps
            # Staircase mode only decays at whole multiples of decay_steps.
            exponent = ops.floor(ratio) if staircase else ratio
            return learning_rate * (decay_rate**exponent)

        return imperate_lr.ExponentialDecay(learning_rate, decay_steps,
                                            decay_rate, staircase)
Q
Qiao Longfei 已提交
154 155


Y
Yu Yang 已提交
156
def natural_exp_decay(learning_rate, decay_steps, decay_rate, staircase=False):
    """Applies natural exponential decay to the initial learning rate.

    >>> if not staircase:
    >>>     decayed_learning_rate = learning_rate * exp(- decay_rate * (global_step / decay_steps))
    >>> else:
    >>>     decayed_learning_rate = learning_rate * exp(- decay_rate * floor(global_step / decay_steps))

    Args:
        learning_rate: A scalar float32 value or a Variable. This
          will be the initial learning rate during training.
        decay_steps: A Python `int32` number.
        decay_rate: A Python `float` number.
        staircase: Boolean. If set true, decay the learning rate every
          decay_steps.

    Returns:
        The decayed learning rate. In imperative (dygraph) mode an
        ``imperate_lr.NaturalExpDecay`` object is returned instead.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          base_lr = 0.1
          sgd_optimizer = fluid.optimizer.SGD(
              learning_rate=fluid.layers.natural_exp_decay(
                  learning_rate=base_lr,
                  decay_steps=10000,
                  decay_rate=0.5,
                  staircase=True))

    """
    with default_main_program()._lr_schedule_guard():
        if not imperative_base.enabled():
            step = _decay_step_counter()
            ratio = step / decay_steps
            # Staircase mode only decays at whole multiples of decay_steps.
            exponent = ops.floor(ratio) if staircase else ratio
            return learning_rate * ops.exp(-1 * decay_rate * exponent)

        return imperate_lr.NaturalExpDecay(learning_rate, decay_steps,
                                           decay_rate, staircase)
Q
Qiao Longfei 已提交
201 202


Y
Yu Yang 已提交
203
def inverse_time_decay(learning_rate, decay_steps, decay_rate, staircase=False):
    """
    Applies inverse time decay to the initial learning rate.

    When training a model, it is often recommended to lower the learning rate as the
    training progresses. By using this function, an inverse decay function will be
    applied to the initial learning rate.

    >>> if staircase == True:
    >>>     decayed_learning_rate = learning_rate / (1 + decay_rate * floor(global_step / decay_steps))
    >>> else:
    >>>     decayed_learning_rate = learning_rate / (1 + decay_rate * global_step / decay_steps)

    Args:
        learning_rate(Variable|float): The initial learning rate.
        decay_steps(int): See the decay computation above.
        decay_rate(float): The decay rate. See the decay computation above.
        staircase(bool): If True, decay the learning rate at discrete
                         intervals. Default: False

    Returns:
        Variable: The decayed learning rate. In imperative (dygraph) mode an
        ``imperate_lr.InverseTimeDecay`` object is returned instead.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          base_lr = 0.1
          sgd_optimizer = fluid.optimizer.SGD(
              learning_rate=fluid.layers.inverse_time_decay(
                  learning_rate=base_lr,
                  decay_steps=10000,
                  decay_rate=0.5,
                  staircase=True))
    """
    with default_main_program()._lr_schedule_guard():
        if imperative_base.enabled():
            decay = imperate_lr.InverseTimeDecay(learning_rate, decay_steps,
                                                 decay_rate, staircase)
            return decay
        else:
            global_step = _decay_step_counter()

            div_res = global_step / decay_steps
            if staircase:
                # Only decay at whole multiples of decay_steps.
                div_res = ops.floor(div_res)

            decayed_lr = learning_rate / (1 + decay_rate * div_res)

            return decayed_lr
253 254 255 256 257 258 259


def polynomial_decay(learning_rate,
                     decay_steps,
                     end_learning_rate=0.0001,
                     power=1.0,
                     cycle=False):
    """
    Applies polynomial decay to the initial learning rate.

    .. code-block:: text

     if cycle:
       decay_steps = decay_steps * ceil(global_step / decay_steps)
     else:
       global_step = min(global_step, decay_steps)
     decayed_learning_rate = (learning_rate - end_learning_rate) *
          (1 - global_step / decay_steps) ^ power + end_learning_rate

    In cycle mode the multiplier ``ceil(global_step / decay_steps)`` is
    taken as 1 at ``global_step == 0`` so that ``decay_steps`` never becomes
    zero.

    Args:
        learning_rate(Variable|float32): A scalar float32 value or a Variable. This
          will be the initial learning rate during training.
        decay_steps(int32): A Python `int32` number.
        end_learning_rate(float): A Python `float` number.
        power(float): A Python `float` number.
        cycle(bool): If set true, decay the learning rate every decay_steps.

    Returns:
        Variable: The decayed learning rate. In imperative (dygraph) mode an
        ``imperate_lr.PolynomialDecay`` object is returned instead.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          start_lr = 0.01
          total_step = 5000
          end_lr = 0
          lr = fluid.layers.polynomial_decay(
              start_lr, total_step, end_lr, power=1)

    """
    with default_main_program()._lr_schedule_guard():
        if imperative_base.enabled():
            decay = imperate_lr.PolynomialDecay(learning_rate, decay_steps,
                                                end_learning_rate, power, cycle)
            return decay
        else:
            global_step = _decay_step_counter()

            if cycle:
                div_res = ops.ceil(global_step / decay_steps)
                zero_var = tensor.fill_constant(
                    shape=[1], dtype='float32', value=0.0)
                one_var = tensor.fill_constant(
                    shape=[1], dtype='float32', value=1.0)

                # ceil(0 / decay_steps) is 0, which would zero decay_steps and
                # divide by zero below; use a multiplier of 1 at step 0.
                with control_flow.Switch() as switch:
                    with switch.case(global_step == zero_var):
                        tensor.assign(input=one_var, output=div_res)
                decay_steps = decay_steps * div_res
            else:
                # Clamp the step so the rate stays at end_learning_rate once
                # decay_steps have elapsed.
                decay_steps_var = tensor.fill_constant(
                    shape=[1], dtype='float32', value=float(decay_steps))
                global_step = nn.elementwise_min(
                    x=global_step, y=decay_steps_var)

            decayed_lr = (learning_rate - end_learning_rate) * \
                ((1 - global_step / decay_steps) ** power) + end_learning_rate
            return decayed_lr
322 323


Y
Yu Yang 已提交
324
def piecewise_decay(boundaries, values):
    """Applies piecewise decay to the initial learning rate.

    The algorithm can be described as the code below.

    .. code-block:: text

      boundaries = [10000, 20000]
      values = [1.0, 0.5, 0.1]
      if step < 10000:
          learning_rate = 1.0
      elif 10000 <= step < 20000:
          learning_rate = 0.5
      else:
          learning_rate = 0.1

    Args:
        boundaries: A list of step numbers.
        values: A list of learning rate values that will be picked during
            different step boundaries. It must contain exactly one more
            element than ``boundaries``.

    Returns:
        The decayed learning rate. In imperative (dygraph) mode an
        ``imperate_lr.PiecewiseDecay`` object is returned instead.

    Raises:
        ValueError: If ``len(values) - len(boundaries) != 1``.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          boundaries = [10000, 20000]
          values = [1.0, 0.5, 0.1]
          optimizer = fluid.optimizer.Momentum(
              momentum=0.9,
              learning_rate=fluid.layers.piecewise_decay(boundaries=boundaries, values=values),
              regularization=fluid.regularizer.L2Decay(1e-4))

    """
    with default_main_program()._lr_schedule_guard():
        if len(values) - len(boundaries) != 1:
            raise ValueError("len(values) - len(boundaries) should be 1")

        if imperative_base.enabled():
            return imperate_lr.PiecewiseDecay(boundaries, values, 0)

        step = _decay_step_counter()

        lr = tensor.create_global_var(
            shape=[1],
            value=0.0,
            dtype='float32',
            persistable=True,
            name="learning_rate")

        with control_flow.Switch() as switch:
            # One case per boundary; together they express the if/elif chain
            # shown in the docstring. zip stops at the last boundary, leaving
            # the final value for the default branch.
            for boundary, value in zip(boundaries, values):
                boundary_var = tensor.fill_constant(
                    shape=[1],
                    dtype='float32',
                    value=float(boundary),
                    force_cpu=True)
                value_var = tensor.fill_constant(
                    shape=[1], dtype='float32', value=float(value))
                with switch.case(step < boundary_var):
                    tensor.assign(value_var, lr)
            final_var = tensor.fill_constant(
                shape=[1],
                dtype='float32',
                value=float(values[-1]))
            with switch.default():
                tensor.assign(final_var, lr)

        return lr
W
Wu Yi 已提交
396 397


S
shippingwang 已提交
398 399 400 401
def cosine_decay(learning_rate, step_each_epoch, epochs):
    r"""
    Applies cosine decay to the learning rate.

    When training a model, it is often recommended to lower the learning rate as the
    training progresses. By using this function, the learning rate will be decayed by
    following the cosine decay strategy.

    .. math::

        decayed\_lr = learning\_rate * 0.5 * (\cos(epoch * \frac{\pi}{epochs}) + 1)

    where :math:`epoch = \lfloor global\_step / step\_each\_epoch \rfloor`.

    Args:
        learning_rate(Variable|float): The initial learning rate.
        step_each_epoch(int): the number of steps in an epoch.
        epochs(int): the number of epochs.

    Returns:
        Variable: The decayed learning rate. In imperative (dygraph) mode an
        ``imperate_lr.CosineDecay`` object is returned instead.

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid
            base_lr = 0.1
            lr = fluid.layers.cosine_decay(
                learning_rate=base_lr, step_each_epoch=10000, epochs=120)
    """

    with default_main_program()._lr_schedule_guard():
        if imperative_base.enabled():
            decay = imperate_lr.CosineDecay(learning_rate, step_each_epoch,
                                            epochs)
            return decay
        else:
            global_step = _decay_step_counter()

            # The rate changes once per epoch, not once per step.
            cur_epoch = ops.floor(global_step / step_each_epoch)
            decayed_lr = learning_rate * 0.5 * (
                ops.cos(cur_epoch * math.pi / epochs) + 1)
            return decayed_lr
S
shippingwang 已提交
439 440


441 442 443 444 445 446 447 448 449 450 451 452 453 454
def linear_lr_warmup(learning_rate, warmup_steps, start_lr, end_lr):
    """
    Applies linear learning rate warmup before the normal learning rate
    scheduling.

    .. code-block:: text

     if global_step < warmup_steps:
         linear_step = end_lr - start_lr
         lr = start_lr + linear_step * (global_step / warmup_steps)
     else:
         lr = learning_rate

    Args:
        learning_rate (float | Variable): The learning rate (or learning rate
            schedule) used once warmup has finished. If it is a Variable, the
            returned learning rate has the same dtype; otherwise float32.
        warmup_steps (int): The warmup steps.
        start_lr (float): The start learning rate of warmup.
        end_lr (float): The end learning rate of warmup.

    Returns:
        Variable: The learning rate, linearly warmed up from ``start_lr``
        towards ``end_lr`` during the first ``warmup_steps`` steps and equal
        to ``learning_rate`` afterwards.

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid
            boundaries = [100, 200]
            lr_steps = [0.1, 0.01, 0.001]
            warmup_steps = 50
            start_lr = 1. / 3.
            end_lr = 0.1
            decayed_lr = fluid.layers.linear_lr_warmup(
                fluid.layers.piecewise_decay(boundaries, lr_steps),
                warmup_steps, start_lr, end_lr)

    """
    # The output variable follows learning_rate's dtype when it is a Variable.
    dtype = 'float32'
    if isinstance(learning_rate, Variable):
        dtype = learning_rate.dtype

    linear_step = float(end_lr) - float(start_lr)
    with default_main_program()._lr_schedule_guard():
        lr = tensor.create_global_var(
            shape=[1],
            value=0.0,
            dtype=dtype,
            persistable=True,
            name="learning_rate_warmup")

        global_step = _decay_step_counter()

        with control_flow.Switch() as switch:
            with switch.case(global_step < warmup_steps):
                decayed_lr = start_lr + linear_step * (global_step /
                                                       float(warmup_steps))
                tensor.assign(decayed_lr, lr)
            with switch.default():
                # A Python scalar must become a tensor before it can be
                # assigned into lr.
                if not isinstance(learning_rate, Variable):
                    learning_rate = tensor.fill_constant(
                        shape=[1], dtype=dtype, value=float(learning_rate))
                tensor.assign(learning_rate, lr)
    return lr