# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import math

from .. import unique_name

# Public scheduler classes exported by this module.
__all__ = [
    'NoamDecay', 'PiecewiseDecay', 'NaturalExpDecay', 'ExponentialDecay',
    'InverseTimeDecay', 'PolynomialDecay', 'CosineDecay'
]


class LearningRateDecay(object):
    """
    Base class of learning rate decay.

    Defines the common interface of a LearningRateDecay. Users should not
    use this class directly, but one of its subclasses instead.

    Subclasses implement ``step()``, which computes the learning rate for the
    current ``self.step_num``. Calling the scheduler object returns that
    learning rate and then advances ``self.step_num`` by ``self.step_size``.

    Args:
        begin(int): The initial value of ``self.step_num`` (default is 0).
        step(int): The amount ``self.step_num`` grows by on every call
            (default is 1).
        dtype(str): The dtype used to create the learning rate variable
            (default is 'float32').
    """

    def __init__(self, begin=0, step=1, dtype='float32'):
        self.step_num = begin
        self.step_size = step
        self.dtype = dtype

    def __call__(self):
        """
        Compute the learning rate for the current step, then advance the step.

        Returns:
            The value produced by ``step()``. A plain real number (Python
            int/float or NumPy scalar) is first converted to a learning rate
            variable via ``create_lr_var``; anything else is returned as is.
        """
        import numbers
        lr = self.step()
        # Convert any real scalar, not only Python floats: schedulers may
        # legitimately yield ints (e.g. values=[1, 0.5, 0.1]) or NumPy scalars,
        # which previously leaked through unconverted.
        if isinstance(lr, numbers.Real):
            lr = self.create_lr_var(lr)
        self.step_num += self.step_size
        return lr

    def create_lr_var(self, lr):
        """
        Convert lr from a Python number to a variable.

        Args:
            lr: learning rate (anything accepted by ``float()``).
        Returns:
            A persistable global variable of shape [1] and dtype
            ``self.dtype`` holding ``float(lr)``, with a freshly generated
            unique name.
        """
        # Imported lazily, matching the rest of this module.
        from .. import layers
        lr = layers.create_global_var(
            name=unique_name.generate("learning_rate"),
            shape=[1],
            value=float(lr),
            dtype=self.dtype,
            persistable=True)
        return lr

    def step(self):
        """Compute the learning rate for ``self.step_num``; subclasses must override."""
        raise NotImplementedError()


class PiecewiseDecay(LearningRateDecay):
    """
    Piecewise decay scheduler.

    The algorithm can be described as the code below.

    .. code-block:: text

      boundaries = [10000, 20000]
      values = [1.0, 0.5, 0.1]
      if step < 10000:
          learning_rate = 1.0
      elif 10000 <= step < 20000:
          learning_rate = 0.5
      else:
          learning_rate = 0.1

    Args:
        boundaries(list): A list of step numbers.
        values(list): A list of learning rate values that will be picked during
            different step boundaries. Must contain exactly
            ``len(boundaries) + 1`` entries.
        begin(int): The begin step used to initialize ``self.step_num``.
        step(int): The step size used when computing the new step_num
            (default is 1).
        dtype(str): The dtype used to create the learning rate variable
            (default is 'float32').

    Raises:
        ValueError: If ``len(values) != len(boundaries) + 1``.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          boundaries = [10000, 20000]
          values = [1.0, 0.5, 0.1]
          with fluid.dygraph.guard():
              optimizer = fluid.optimizer.SGD(
                 learning_rate=fluid.dygraph.PiecewiseDecay(boundaries, values, 0) )
    """

    def __init__(self, boundaries, values, begin, step=1, dtype='float32'):
        super(PiecewiseDecay, self).__init__(begin, step, dtype)
        # Fail fast instead of raising IndexError (too few values) or silently
        # picking the wrong final value (too many values) at step time.
        if len(values) != len(boundaries) + 1:
            raise ValueError(
                "PiecewiseDecay expects len(values) == len(boundaries) + 1, "
                "got %d values for %d boundaries" %
                (len(values), len(boundaries)))
        self.boundaries = boundaries
        self.values = values
        # Private copy so later mutation of the caller's list has no effect.
        self.vars = list(values)

    def step(self):
        """Return the learning rate variable for the interval containing step_num."""
        for i, boundary in enumerate(self.boundaries):
            if self.step_num < boundary:
                return self.create_lr_var(self.vars[i])
        # Past every boundary: use the value following the last boundary.
        return self.create_lr_var(self.vars[-1])


class NaturalExpDecay(LearningRateDecay):
    """
    Applies natural exponential decay to the initial learning rate.

    .. code-block:: python

        if not staircase:
            decayed_learning_rate = learning_rate * exp(- decay_rate * (global_step / decay_steps))
        else:
            decayed_learning_rate = learning_rate * exp(- decay_rate * floor(global_step / decay_steps))

    Args:
        learning_rate: A scalar float32 value or a Variable. This
          will be the initial learning rate during training.
        decay_steps: A Python `int32` number.
        decay_rate: A Python `float` number.
        staircase: Boolean. If set true, decay the learning rate every decay_steps.
        begin: A Python 'int32' number, the begin step (Default is 0)
        step: A Python 'int32' number, the step size (Default is 1)
        dtype: A Python 'str', the dtype used to create learning rate variable (Default is 'float32')

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          base_lr = 0.1
          with fluid.dygraph.guard():
              sgd_optimizer = fluid.optimizer.SGD(
                  learning_rate=fluid.dygraph.NaturalExpDecay(
                      learning_rate=base_lr,
                      decay_steps=10000,
                      decay_rate=0.5,
                      staircase=True))

    """

    def __init__(self,
                 learning_rate,
                 decay_steps,
                 decay_rate,
                 staircase=False,
                 begin=0,
                 step=1,
                 dtype='float32'):
        super(NaturalExpDecay, self).__init__(begin, step, dtype)
        self.learning_rate = learning_rate
        self.decay_steps = decay_steps
        self.decay_rate = decay_rate
        self.staircase = staircase

    def step(self):
        """Return learning_rate * exp(-decay_rate * step_num / decay_steps)."""
        from .. import layers
        # Force true division: under Python 2 an int / int floors, which would
        # silently turn the continuous schedule into a staircase one.
        div_res = self.create_lr_var(float(self.step_num) / self.decay_steps)
        if self.staircase:
            div_res = layers.floor(div_res)
        decayed_lr = self.learning_rate * layers.exp(-1 * self.decay_rate *
                                                     div_res)

        return decayed_lr


class ExponentialDecay(LearningRateDecay):
    """
    Applies exponential decay to the learning rate.

    When training a model, it is often recommended to lower the learning rate as the
    training progresses. By using this function, the learning rate will be decayed by
    'decay_rate' every 'decay_steps' steps.

    .. code-block:: python

        if staircase == True:
            decayed_learning_rate = learning_rate * decay_rate ^ floor(global_step / decay_steps)
        else:
            decayed_learning_rate = learning_rate * decay_rate ^ (global_step / decay_steps)

    Args:
        learning_rate(Variable|float): The initial learning rate.
        decay_steps(int): See the decay computation above.
        decay_rate(float): The decay rate. See the decay computation above.
        staircase(Boolean): If True, decay the learning rate at discrete intervals.
                            Default: False
        begin(int): The begin step (default is 0)
        step(int): The step size (default is 1)
        dtype(str): The dtype used to create learning rate (default is 'float32')

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          base_lr = 0.1
          with fluid.dygraph.guard():
              sgd_optimizer = fluid.optimizer.SGD(
                  learning_rate=fluid.dygraph.ExponentialDecay(
                      learning_rate=base_lr,
                      decay_steps=10000,
                      decay_rate=0.5,
                      staircase=True))

    """

    def __init__(self,
                 learning_rate,
                 decay_steps,
                 decay_rate,
                 staircase=False,
                 begin=0,
                 step=1,
                 dtype='float32'):
        super(ExponentialDecay, self).__init__(begin, step, dtype)
        self.learning_rate = learning_rate
        self.decay_steps = decay_steps
        self.decay_rate = decay_rate
        self.staircase = staircase

    def step(self):
        """Return learning_rate * decay_rate ** (step_num / decay_steps)."""
        from .. import layers
        # Force true division: under Python 2 an int / int floors, which would
        # silently turn the continuous schedule into a staircase one.
        div_res = self.create_lr_var(float(self.step_num) / self.decay_steps)
        if self.staircase:
            div_res = layers.floor(div_res)

        decayed_lr = self.learning_rate * (self.decay_rate**div_res)

        return decayed_lr


class InverseTimeDecay(LearningRateDecay):
    """
    Applies inverse time decay to the initial learning rate.

    When training a model, it is often recommended to lower the learning rate as the
    training progresses. By using this function, an inverse decay function will be
    applied to the initial learning rate.

    .. code-block:: python

        if staircase == True:
            decayed_learning_rate = learning_rate / (1 + decay_rate * floor(global_step / decay_steps))
        else:
            decayed_learning_rate = learning_rate / (1 + decay_rate * global_step / decay_steps)

    Args:
        learning_rate(Variable|float): The initial learning rate.
        decay_steps(int): See the decay computation above.
        decay_rate(float): The decay rate. See the decay computation above.
        staircase(Boolean): If True, decay the learning rate at discrete intervals.
                            Default: False
        begin(int): The begin step (default is 0)
        step(int): The step size (default is 1)
        dtype(str): The dtype used to create learning rate (default is 'float32')

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          base_lr = 0.1
          with fluid.dygraph.guard():
              sgd_optimizer = fluid.optimizer.SGD(
                  learning_rate=fluid.dygraph.InverseTimeDecay(
                      learning_rate=base_lr,
                      decay_steps=10000,
                      decay_rate=0.5,
                      staircase=True))

    """

    def __init__(self,
                 learning_rate,
                 decay_steps,
                 decay_rate,
                 staircase=False,
                 begin=0,
                 step=1,
                 dtype='float32'):
        super(InverseTimeDecay, self).__init__(begin, step, dtype)
        self.learning_rate = learning_rate
        self.decay_steps = decay_steps
        self.decay_rate = decay_rate
        self.staircase = staircase

    def step(self):
        """Return learning_rate / (1 + decay_rate * step_num / decay_steps)."""
        from .. import layers
        # Force true division: under Python 2 an int / int floors, which would
        # silently turn the continuous schedule into a staircase one.
        div_res = self.create_lr_var(float(self.step_num) / self.decay_steps)
        if self.staircase:
            div_res = layers.floor(div_res)

        decayed_lr = self.learning_rate / (1 + self.decay_rate * div_res)

        return decayed_lr


class PolynomialDecay(LearningRateDecay):
    """
    Applies polynomial decay to the initial learning rate.

    .. code-block:: text

     if cycle:
       decay_steps = decay_steps * ceil(global_step / decay_steps)
     else:
       global_step = min(global_step, decay_steps)
     decayed_learning_rate = (learning_rate - end_learning_rate) *
          (1 - global_step / decay_steps) ^ power + end_learning_rate

    When ``cycle`` is true and ``global_step`` is 0, the ceil factor is taken
    as 1 so that the first cycle has length ``decay_steps``.

    Args:
        learning_rate(Variable|float32): A scalar float32 value or a Variable. This
          will be the initial learning rate during training.
        decay_steps(int32): A Python `int32` number.
        end_learning_rate(float): A Python `float` number.
        power(float): A Python `float` number.
        cycle(bool): If set true, decay the learning rate every decay_steps.
        begin(int): The begin step (default is 0)
        step(int): The step size (default is 1)
        dtype(str): The dtype used to create learning rate (default is 'float32')

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          start_lr = 0.01
          total_step = 5000
          end_lr = 0
          with fluid.dygraph.guard():
              optimizer  = fluid.optimizer.SGD(
                  learning_rate = fluid.dygraph.PolynomialDecay(
                  start_lr, total_step, end_lr, power=1.0) )

    """

    def __init__(self,
                 learning_rate,
                 decay_steps,
                 end_learning_rate=0.0001,
                 power=1.0,
                 cycle=False,
                 begin=0,
                 step=1,
                 dtype='float32'):
        super(PolynomialDecay, self).__init__(begin, step, dtype)
        self.learning_rate = learning_rate
        self.decay_steps = decay_steps
        self.end_learning_rate = end_learning_rate
        self.power = power
        self.cycle = cycle

    def step(self):
        """Return the polynomially decayed learning rate for step_num."""
        from .. import layers
        tmp_step_num = self.step_num
        tmp_decay_steps = self.decay_steps
        if self.cycle:
            if tmp_step_num == 0:
                # ceil(0) would be 0 and make the decay period 0; treat the
                # very first step as part of the first cycle instead.
                div_res = self.create_lr_var(1.0)
            else:
                div_res = layers.ceil(
                    self.create_lr_var(tmp_step_num / float(self.decay_steps)))
            # Stretch the period to the end of the current cycle.
            tmp_decay_steps = self.decay_steps * div_res
        else:
            # Clamp so the rate stays at end_learning_rate after decay_steps.
            tmp_step_num = self.create_lr_var(
                min(tmp_step_num, self.decay_steps))

        decayed_lr = (self.learning_rate - self.end_learning_rate) * \
            ((1 - tmp_step_num / tmp_decay_steps) ** self.power) + self.end_learning_rate
        return decayed_lr


class CosineDecay(LearningRateDecay):
    r"""
    Applies cosine decay to the learning rate.

    When training a model, it is often recommended to lower the learning rate as the
    training progresses. By using this function, the learning rate will be decayed by
    the following cosine decay strategy.

    .. math::

        decayed\_lr = learning\_rate * 0.5 * (\cos(epoch * \frac{\pi}{epochs}) + 1)

    where ``epoch = floor(step_num / step_each_epoch)``.

    Args:
        learning_rate(Variable|float): The initial learning rate.
        step_each_epoch(int): the number of steps in an epoch.
        epochs(int): the number of epochs.
        begin(int): The begin step (default is 0).
        step(int): The step size (default is 1).
        dtype(str): The dtype used to create learning rate (default is 'float32').

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          base_lr = 0.1
          with fluid.dygraph.guard():
              optimizer  = fluid.optimizer.SGD(
                  learning_rate = fluid.dygraph.CosineDecay(
                      base_lr, 10000, 120) )
    """

    def __init__(self,
                 learning_rate,
                 step_each_epoch,
                 epochs,
                 begin=0,
                 step=1,
                 dtype='float32'):
        super(CosineDecay, self).__init__(begin, step, dtype)
        self.learning_rate = learning_rate
        self.step_each_epoch = step_each_epoch
        self.epochs = epochs

    def step(self):
        """Return the cosine-decayed learning rate for the current epoch."""
        from .. import layers
        # The epoch index is floored, so the rate only changes once per epoch.
        cur_epoch = layers.floor(
            self.create_lr_var(self.step_num / self.step_each_epoch))
        decayed_lr = self.learning_rate * 0.5 * (
            layers.cos(cur_epoch * math.pi / self.epochs) + 1)
        return decayed_lr


class NoamDecay(LearningRateDecay):
    """
    Noam decay method. The numpy implementation of noam decay as follows.

    .. code-block:: python

      import numpy as np
      # set hyper parameters
      d_model = 2
      current_steps = 20
      warmup_steps = 200
      # compute
      lr_value = np.power(d_model, -0.5) * np.min([
                              np.power(current_steps, -0.5),
                              np.power(warmup_steps, -1.5) * current_steps])

    Please reference `attention is all you need
    <https://arxiv.org/pdf/1706.03762.pdf>`_.

    Args:
        d_model(int|float): The dimensionality of input and output of model.
        warmup_steps(int|float): The number of warmup steps (a hyperparameter).
        begin(int): The begin step (default is 1). Must be positive, because
            the step number is raised to the power -0.5.
        step(int): The step size (default is 1)
        dtype(str): The dtype used to create learning rate (default is 'float32')

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          warmup_steps = 100
          learning_rate = 0.01
          with fluid.dygraph.guard():
              optimizer  = fluid.optimizer.SGD(
                  learning_rate = fluid.dygraph.NoamDecay(
                         1/(warmup_steps *(learning_rate ** 2)),
                         warmup_steps) )
    """

    def __init__(self, d_model, warmup_steps, begin=1, step=1, dtype='float32'):
        super(NoamDecay, self).__init__(begin, step, dtype)
        self.d_model = d_model
        self.warmup_steps = warmup_steps

    def step(self):
        """Return d_model^-0.5 * min(step^-0.5, warmup_steps^-1.5 * step)."""
        from .. import layers
        # Decay branch (inverse square root of the step).
        a = self.create_lr_var(self.step_num**-0.5)
        # Linear warmup branch.
        b = self.create_lr_var((self.warmup_steps**-1.5) * self.step_num)
        lr_value = (self.d_model**-0.5) * layers.elementwise_min(a, b)
        return lr_value