# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import math
import warnings

from .. import unique_name
from ..framework import Variable
from ..data_feeder import check_type

__all__ = [
    'NoamDecay', 'PiecewiseDecay', 'NaturalExpDecay', 'ExponentialDecay',
    'InverseTimeDecay', 'PolynomialDecay', 'CosineDecay', 'LinearLrWarmup',
    'ReduceLROnPlateau', 'StepDecay', 'MultiStepDecay', 'LambdaDecay'
]


class LearningRateDecay(object):
    """
    Base class of learning rate decay
    
    Define the common interface of a LearningRateDecay.
    User should not use this class directly,
    but need to use one of its implementations.
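
    Because ``__call__`` invokes ``step()`` and then advances ``step_num`` by ``step_size``,
    a subclass only needs to implement ``step()``. A minimal sketch of a hypothetical
    constant scheduler (for illustration only, not part of the public API):

    .. code-block:: python

        class ConstantDecay(LearningRateDecay):
            # returns the same learning rate at every step, regardless of step_num
            def __init__(self, learning_rate, begin=0, step=1, dtype='float32'):
                super(ConstantDecay, self).__init__(begin, step, dtype)
                self.learning_rate = learning_rate

            def step(self):
                return self.learning_rate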
    """

    def __init__(self, begin=0, step=1, dtype='float32'):
        self.step_num = begin
        self.step_size = step
        self.dtype = dtype

    def __call__(self):
        lr = self.step()
        if isinstance(lr, float):
            lr = self.create_lr_var(lr)
        self.step_num += self.step_size
        return lr

    def create_lr_var(self, lr):
        """
        Convert lr from a float to a Variable.

        Args:
            lr: learning rate
        Returns:
            learning rate variable
        """
        from .. import layers
        lr = layers.create_global_var(
            name=unique_name.generate("learning_rate"),
            shape=[1],
            value=float(lr),
            dtype=self.dtype,
            persistable=False)
        return lr

    # Note: If you want to change what optimizer.state_dict stores, just overwrite this function,
    # "self.step_num" will be stored by default.
    def state_dict(self):
        """
        Returns the state of the scheduler as a :class:`dict`.

        It is a subset of self.__dict__ .
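
        A minimal usage sketch (``scheduler`` here stands for any ``LearningRateDecay``
        subclass instance and is only illustrative; in practice the state is usually
        saved and restored through ``optimizer.state_dict`` / ``optimizer.set_dict``):

        .. code-block:: python

            state = scheduler.state_dict()   # e.g. {'step_num': 100}
            # ... save, reload, or pass to another process ...
            scheduler.set_dict(state)        # restores the scheduler's step_num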
        """
        self._state_keys()
        state_dict = {}
        for key in self.keys:
            if key not in self.__dict__:
                continue
            value = self.__dict__[key]
            if isinstance(value, Variable):
                assert value.shape == [
                    1
                ], "shape of Variable in state_dict must be [1], but received {}".format(
                    value.shape)
                value = value.numpy()[0]
            state_dict[key] = value

        return state_dict

    def _state_keys(self):
        """
        Set the keys in self.__dict__ that need to be saved.
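
        A subclass that needs to save extra state can override this method, for example
        (a hypothetical sketch; ``last_lr`` is not a real attribute of this class):

        .. code-block:: python

            def _state_keys(self):
                self.keys = ['step_num', 'last_lr']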
        """
        self.keys = ['step_num']

    def set_dict(self, state_dict):
        """
        Loads the schedulers state.
        """
        self._state_keys()
        for key in self.keys:
            if key in state_dict:
                self.__dict__[key] = state_dict[key]
            else:
                raise RuntimeError(
                    "Please check whether state_dict is correct for optimizer. Can't find [ {} ] in state_dict".
                    format(key))
        if len(state_dict) > len(self.keys):
            warnings.warn(
                "There are some unused values in state_dict. Maybe the optimizer have different 'LearningRateDecay' when invoking state_dict and set_dict"
            )

    def step(self):
        raise NotImplementedError()


class PiecewiseDecay(LearningRateDecay):
    """
    :api_attr: imperative
    
    Piecewise decay scheduler.

    The algorithm can be described as the code below.

    .. code-block:: text

        boundaries = [10000, 20000]
        values = [1.0, 0.5, 0.1]
        if global_step < 10000:
            learning_rate = 1.0
        elif 10000 <= global_step < 20000:
            learning_rate = 0.5
        else:
            learning_rate = 0.1

    Parameters:
        boundaries(list): A list of step numbers. The type of element in the list is python int.
        values(list): A list of learning rate values that will be picked during
            different step boundaries. The type of element in the list is python float.
        begin(int): The begin step to initialize the global_step in the description above.
        step(int, optional): The step size used to calculate the new global_step in the description above.
            The default value is 1.
        dtype(str, optional): The data type used to create the learning rate variable. The data type can be set as
            'float32', 'float64'. The default value is 'float32'.

    Returns:
        None.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          boundaries = [10000, 20000]
          values = [1.0, 0.5, 0.1]
          with fluid.dygraph.guard():
              emb = fluid.dygraph.Embedding( [10, 10] )
              optimizer = fluid.optimizer.SGD(
                 learning_rate=fluid.dygraph.PiecewiseDecay(boundaries, values, 0),
                 parameter_list = emb.parameters() )
    """

    def __init__(self, boundaries, values, begin, step=1, dtype='float32'):
        super(PiecewiseDecay, self).__init__(begin, step, dtype)
        self.boundaries = boundaries
        self.values = values

        self.vars = []
        for value in values:
            self.vars.append(value)

    def step(self):
        for i in range(len(self.boundaries)):
            if self.step_num < self.boundaries[i]:
                return self.vars[i]
        return self.create_lr_var(self.vars[len(self.values) - 1])


class NaturalExpDecay(LearningRateDecay):
    """
    :api_attr: imperative

    Applies natural exponential decay to the initial learning rate.
    
    The algorithm can be described as follows.

    .. math::

        decayed\_learning\_rate = learning\_rate * e^{y} 

    If staircase is set to False, then:

    .. math::

        y = - decay\_rate * \\frac{global\_step}{decay\_steps}

    If staircase is set to True, then:

    .. math::

        y = - decay\_rate * math.floor(\\frac{global\_step}{decay\_steps}) 

    Parameters:
        learning_rate(Variable|float): The initial learning rate. If the type 
            is Variable, it's a tensor with shape [1], the data type can be  
            float32 or float64. It also can be set to python int number.
        decay_steps(int): The decay step size. It determines the decay cycle.
        decay_rate(float): The decay rate.
        staircase(bool, optional): If set to True, decay the learning rate at discrete intervals. The 
            default value is False.
        begin(int, optional): The begin step. The initial value of global_step described above. The default value is 0.
        step(int, optional): The step size used to calculate the new global_step in the description above.
            The default value is 1.
        dtype(str, optional): The data type used to create the learning rate variable. The data type can be set as
            'float32', 'float64'. The default value is 'float32'.

    Returns:
        None.

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid
            base_lr = 0.1
            with fluid.dygraph.guard():
                emb = fluid.dygraph.Embedding([10, 10])
                sgd_optimizer = fluid.optimizer.SGD(
                        learning_rate=fluid.dygraph.NaturalExpDecay(
                            learning_rate=base_lr,
                            decay_steps=10000,
                            decay_rate=0.5,
                            staircase=True),
                        parameter_list=emb.parameters())

    """

    def __init__(self,
                 learning_rate,
                 decay_steps,
                 decay_rate,
                 staircase=False,
                 begin=0,
                 step=1,
                 dtype='float32'):
        super(NaturalExpDecay, self).__init__(begin, step, dtype)
        self.learning_rate = learning_rate
        self.decay_steps = decay_steps
        self.decay_rate = decay_rate
        self.staircase = staircase

    def step(self):
        from .. import layers
        div_res = self.create_lr_var(self.step_num / self.decay_steps)
        if self.staircase:
            div_res = layers.floor(div_res)
        decayed_lr = self.learning_rate * layers.exp(-1 * self.decay_rate *
                                                     div_res)

        return decayed_lr


class ExponentialDecay(LearningRateDecay):
    """
    :api_attr: imperative

    Applies exponential decay to the learning rate.

    The algorithm can be described as follows.
    
    .. math::

        decayed\_learning\_rate = learning\_rate * decay\_rate ^ y 

    If staircase is set to False, then:

    .. math::

        y = \\frac{global\_step}{decay\_steps} 

    If staircase is set to True, then:

    .. math::

        y = math.floor(\\frac{global\_step}{decay\_steps})


    Parameters:
        learning_rate(Variable|float): The initial learning rate. If the type 
            is Variable, it's a tensor with shape [1], the data type can be  
            float32 or float64. It also can be set to python int number.
        decay_steps(int): The decay step size. It determines the decay cycle.
        decay_rate(float): The decay rate.
        staircase(bool, optional): If set to True, decay the learning rate at discrete intervals. The 
            default value is False.
        begin(int, optional): The begin step. The initial value of global_step described above. The default value is 0.
        step(int, optional): The step size used to calculate the new global_step in the description above.
            The default value is 1.
        dtype(str, optional): The data type used to create the learning rate variable. The data type can be set as
            'float32', 'float64'. The default value is 'float32'.

    Returns:
        None.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          base_lr = 0.1
          with fluid.dygraph.guard():
              emb = fluid.dygraph.Embedding([10, 10])
              sgd_optimizer = fluid.optimizer.SGD(
                  learning_rate=fluid.dygraph.ExponentialDecay(
                      learning_rate=base_lr,
                      decay_steps=10000,
                      decay_rate=0.5,
                      staircase=True),
                  parameter_list=emb.parameters())

    """

    def __init__(self,
                 learning_rate,
                 decay_steps,
                 decay_rate,
                 staircase=False,
                 begin=0,
                 step=1,
                 dtype='float32'):
        super(ExponentialDecay, self).__init__(begin, step, dtype)
        self.learning_rate = learning_rate
        self.decay_steps = decay_steps
        self.decay_rate = decay_rate
        self.staircase = staircase

    def step(self):
        from .. import layers
        div_res = self.create_lr_var(self.step_num / self.decay_steps)
        if self.staircase:
            div_res = layers.floor(div_res)

        decayed_lr = self.learning_rate * (self.decay_rate**div_res)

        return decayed_lr


class InverseTimeDecay(LearningRateDecay):
    """
    :api_attr: imperative

    Applies inverse time decay to the initial learning rate.

    The algorithm can be described as follows.
    If staircase is set to False, then:

    .. math::

        decayed\_learning\_rate = \\frac{learning\_rate}{1 + decay\_rate * \\frac{global\_step}{decay\_step}}  

    If staircase is set to True, then:

    .. math::

        decayed\_learning\_rate = \\frac{learning\_rate}{1 + decay\_rate * math.floor(\\frac{global\_step}{decay\_step})}

    Parameters:
        learning_rate(Variable|float): The initial learning rate. If the type 
            is Variable, it's a tensor with shape [1], the data type can be  
            float32 or float64. It also can be set to python int number.
        decay_steps(int): The decay step size. It determines the decay cycle.
        decay_rate(float): The decay rate.
        staircase(bool, optional): If set to True, decay the learning rate at discrete intervals. The 
            default value is False.
        begin(int, optional): The begin step. The initial value of global_step described above. The default value is 0.
        step(int, optional): The step size used to calculate the new global_step in the description above.
            The default value is 1.
        dtype(str, optional): The data type used to create the learning rate variable. The data type can be 
            'float32', 'float64'. The default value is 'float32'.

    Returns:
        None.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          base_lr = 0.1
          with fluid.dygraph.guard():
              emb = fluid.dygraph.Embedding([10, 10])
              sgd_optimizer = fluid.optimizer.SGD(
                  learning_rate=fluid.dygraph.InverseTimeDecay(
                      learning_rate=base_lr,
                      decay_steps=10000,
                      decay_rate=0.5,
                      staircase=True),
                  parameter_list = emb.parameters())

    """

    def __init__(self,
                 learning_rate,
                 decay_steps,
                 decay_rate,
                 staircase=False,
                 begin=0,
                 step=1,
                 dtype='float32'):
        super(InverseTimeDecay, self).__init__(begin, step, dtype)
        self.learning_rate = learning_rate
        self.decay_steps = decay_steps
        self.decay_rate = decay_rate
        self.staircase = staircase

    def step(self):
        from .. import layers
        div_res = self.create_lr_var(self.step_num / self.decay_steps)
        if self.staircase:
            div_res = layers.floor(div_res)

        decayed_lr = self.learning_rate / (1 + self.decay_rate * div_res)

        return decayed_lr


class PolynomialDecay(LearningRateDecay):
    """
    :api_attr: imperative

    Applies polynomial decay to the initial learning rate.

    The algorithm can be described as follows.

    If cycle is set to True, then:

    .. math::

        decay\_steps & = decay\_steps * math.ceil(\\frac{global\_step}{decay\_steps}) 

        decayed\_learning\_rate & = (learning\_rate-end\_learning\_rate)*(1-\\frac{global\_step}{decay\_steps})^{power}+end\_learning\_rate

    If cycle is set to False, then:

    .. math::

        global\_step & = min(global\_step, decay\_steps) 

        decayed\_learning\_rate & = (learning\_rate-end\_learning\_rate)*(1-\\frac{global\_step}{decay\_steps})^{power}+end\_learning\_rate

    Parameters:
        learning_rate(Variable|float): The initial learning rate. If the type 
            is Variable, it's a tensor with shape [1], the data type can be  
            float32 or float64. It also can be set to python int number.
        decay_steps(int): The decay step size. It determines the decay cycle.
        end_learning_rate(float, optional): The minimum final learning rate. The default value is 0.0001.
        power(float, optional): Power of polynomial. The default value is 1.0.
        cycle(bool, optional): If set true, decay the learning rate every decay_steps. The default value is False.
        begin(int, optional): The begin step. The initial value of global_step described above. The default value is 0.
        step(int, optional): The step size used to calculate the new global_step in the description above.
            The default value is 1.
        dtype(str, optional): The data type used to create the learning rate variable. The data type can be set as
            'float32', 'float64'. The default value is 'float32'.

    Returns:
        None.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          start_lr = 0.01
          total_step = 5000
          end_lr = 0
          with fluid.dygraph.guard():
              emb = fluid.dygraph.Embedding( [10, 10])
              optimizer  = fluid.optimizer.SGD(
                  learning_rate = fluid.dygraph.PolynomialDecay(
                  start_lr, total_step, end_lr, power=1.0),
                  parameter_list = emb.parameters())

    """

    def __init__(self,
                 learning_rate,
                 decay_steps,
                 end_learning_rate=0.0001,
                 power=1.0,
                 cycle=False,
                 begin=0,
                 step=1,
                 dtype='float32'):
        super(PolynomialDecay, self).__init__(begin, step, dtype)
        self.learning_rate = learning_rate
        self.decay_steps = decay_steps
        self.end_learning_rate = end_learning_rate
        self.power = power
        self.cycle = cycle

    def step(self):
        from .. import layers
        tmp_step_num = self.step_num
        tmp_decay_steps = self.decay_steps
        if self.cycle:
            div_res = layers.ceil(
                self.create_lr_var(tmp_step_num / float(self.decay_steps)))

            if tmp_step_num == 0:
                div_res = self.create_lr_var(1.0)
            tmp_decay_steps = self.decay_steps * div_res
        else:
            tmp_step_num = self.create_lr_var(tmp_step_num
                                              if tmp_step_num < self.decay_steps
                                              else self.decay_steps)

        decayed_lr = (self.learning_rate - self.end_learning_rate) * \
            ((1 - tmp_step_num / tmp_decay_steps) ** self.power) + self.end_learning_rate
        return decayed_lr


class CosineDecay(LearningRateDecay):
    """
    :api_attr: imperative

    Applies cosine decay to the learning rate.

    The algorithm can be described as follows.

    .. math::

        decayed\_learning\_rate = learning\_rate * 0.5 * (math.cos(global\_step * \\frac{math.pi}{step\_each\_epoch} ) + 1)
    
    Parameters:
        learning_rate(Variable|float): The initial learning rate. If the type 
            is Variable, it's a tensor with shape [1], the data type can be  
            float32 or float64. It also can be set to python int number.
        step_each_epoch(int): The number of steps in an epoch.
        epochs(int): The number of epochs.
        begin(int, optional): The begin step. The initial value of global_step described above. The default value is 0.
        step(int, optional): The step size used to calculate the new global_step in the description above.
            The default value is 1.
        dtype(str, optional): The data type used to create the learning rate variable. The data type can be set as
            'float32', 'float64'. The default value is 'float32'.

    Returns:
        None.

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid
            base_lr = 0.1
            with fluid.dygraph.guard():
                emb = fluid.dygraph.Embedding([10, 10])
                optimizer = fluid.optimizer.SGD(
                    learning_rate = fluid.dygraph.CosineDecay(
                        base_lr, 10000, 120),
                    parameter_list = emb.parameters())
    """

    def __init__(self,
                 learning_rate,
                 step_each_epoch,
                 epochs,
                 begin=0,
                 step=1,
                 dtype='float32'):
        super(CosineDecay, self).__init__(begin, step, dtype)
        self.learning_rate = learning_rate
        self.step_each_epoch = step_each_epoch
        self.epochs = epochs

    def step(self):
        from .. import layers
        cur_epoch = layers.floor(
            self.create_lr_var(self.step_num / self.step_each_epoch))
        decayed_lr = self.learning_rate * 0.5 * (
            layers.cos(cur_epoch * math.pi / self.epochs) + 1)
        return decayed_lr


class NoamDecay(LearningRateDecay):
    """
    :api_attr: imperative

    Applies Noam decay to the initial learning rate. 

    The algorithm can be described as follows.

    .. math::

        decayed\_learning\_rate = learning\_rate * d_{model}^{-0.5} * min(global\_step^{-0.5}, global\_step * warmup\_steps^{-1.5})

    Please refer to `Attention Is All You Need <https://arxiv.org/pdf/1706.03762.pdf>`_

    Parameters:
        d_model(Variable|int): The dimensionality of input and output feature vectors of the model. If the type is Variable,
            it's a tensor with shape [1] and the data type can be int32 or int64. The type can also be python int.
        warmup_steps(Variable|int): The number of warmup steps. A hyperparameter. If the type is Variable,
            it's a tensor with shape [1] and the data type can be int32 or int64. The type can also be python int.
        begin(int, optional): The begin step. The initial value of global_step described above. The default value is 0.
        step(int, optional): The step size used to calculate the new global_step in the description above.
            The default value is 1.
        dtype(str, optional): The data type used to create the learning rate variable. The data type can be set as
            'float32', 'float64'. The default value is 'float32'.
        learning_rate(Variable|float|int): The initial learning rate. If the type
            is Variable, it's a tensor with shape [1], the data type can be
            float32 or float64. It also can be set to python int number. Default 1.0

    Returns:
        None.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          warmup_steps = 100
          learning_rate = 0.01
          with fluid.dygraph.guard():
              emb = fluid.dygraph.Embedding([10, 10])
              optimizer  = fluid.optimizer.SGD(
                  learning_rate = fluid.dygraph.NoamDecay(
                         1/(warmup_steps *(learning_rate ** 2)),
                         warmup_steps),
                  parameter_list = emb.parameters())
    """

    def __init__(self,
                 d_model,
                 warmup_steps,
                 begin=1,
                 step=1,
                 dtype='float32',
                 learning_rate=1.0):
        super(NoamDecay, self).__init__(begin, step, dtype)
        self.learning_rate = learning_rate
        self.d_model = d_model
        self.warmup_steps = warmup_steps

    def step(self):
        from .. import layers
        a = self.create_lr_var(self.step_num**-0.5)
        b = self.create_lr_var((self.warmup_steps**-1.5) * self.step_num)
        lr_value = self.learning_rate * (self.d_model
                                         **-0.5) * layers.elementwise_min(a, b)
        return lr_value


class LinearLrWarmup(LearningRateDecay):
    """
    :api_attr: imperative

    This operator uses the linear learning rate warm-up strategy to adjust the learning rate preliminarily before the normal learning rate scheduling.
    For more information, please refer to `Bag of Tricks for Image Classification with Convolutional Neural Networks <https://arxiv.org/abs/1812.01187>`_
    
    When global_step < warmup_steps, learning rate is updated as:
    
    .. code-block:: text
    
            linear_step = end_lr - start_lr
            lr = start_lr + linear_step * (global_step / warmup_steps)
    
    where start_lr is the initial learning rate, and end_lr is the final learning rate;
    
    When global_step >= warmup_steps, learning rate is updated as:
    
    .. code-block:: text
    
            lr = learning_rate
    
    where lr is the learning_rate after warm-up.
    
    Args:
        learning_rate (Variable|float): Learning rate after warm-up. It can be a 1-D Tensor or a single value with the data type of float32.
        warmup_steps (int): Steps for warm up.
        start_lr (float): Initial learning rate of warm up.
        end_lr (float): Final learning rate of warm up.
        begin(int, optional): The begin step. The initial value of global_step described above. The default value is 0.
        step(int, optional): The step size used to calculate the new global_step in the description above.
            The default value is 1.
        dtype(str, optional): The data type used to create the learning rate variable. The data type can be set as
            'float32', 'float64'. The default value is 'float32'.
    
    Returns:
        Variable: Warm-up learning rate with the same data type as learning_rate.
    
    
    Examples:
    
    .. code-block:: python
    
        import paddle.fluid as fluid
    
        learning_rate = 0.1 
        warmup_steps = 50
        start_lr = 0
        end_lr = 0.1

        with fluid.dygraph.guard(): 
            lr_decay = fluid.dygraph.LinearLrWarmup( learning_rate, warmup_steps, start_lr, end_lr)
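
            # a hedged continuation of the example: attach the warm-up scheduler to an
            # optimizer; "emb" is an assumed layer here, used only for illustration
            emb = fluid.dygraph.Embedding([10, 10])
            sgd = fluid.optimizer.SGD(learning_rate=lr_decay,
                                      parameter_list=emb.parameters())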
    
       
    """

    def __init__(self,
                 learning_rate,
                 warmup_steps,
                 start_lr,
                 end_lr,
                 begin=1,
                 step=1,
                 dtype='float32'):
        super(LinearLrWarmup, self).__init__(begin, step, dtype)
        type_check = isinstance(learning_rate, float) or isinstance(
            learning_rate, int) or isinstance(learning_rate, LearningRateDecay)
        if not type_check:
            raise TypeError(
                "the type of learning_rate should be [int, float or LearningRateDecay], the current type is {}".
                format(learning_rate))
        self.learning_rate = learning_rate
        self.warmup_steps = warmup_steps
        self.start_lr = start_lr
        assert end_lr > start_lr, "end_lr {} must be greater than start_lr {}".format(
            end_lr, start_lr)
        self.lr_ratio_before_warmup = (
            float(end_lr) - float(start_lr)) / float(warmup_steps)

    def step(self):
        base_lr = self.learning_rate
        if isinstance(self.learning_rate, LearningRateDecay):
            base_lr = base_lr()

        from .. import layers
        if self.step_num < self.warmup_steps:
            return self.lr_ratio_before_warmup * self.step_num + self.start_lr
        else:
            return base_lr


class ReduceLROnPlateau(LearningRateDecay):
    """
    :api_attr: imperative

    Reduce learning rate when ``loss`` has stopped descending. Models often benefit from reducing the learning rate
    by 2 to 10 times once model performance shows no further improvement.

    The ``loss`` is the one which has been passed into ``step`` , it must be a 1-D Tensor with shape [1]. When ``loss``
    stops descending for a ``patience`` number of epochs, the learning rate will be reduced to ``learning_rate * decay_rate`` .
    (Specially, ``mode`` can also be set to ``'max'`` , in this case, when ``loss`` stops ascending for a ``patience`` number
    of epochs, the learning rate will be reduced.)

    In addition, after each reduction, it will wait a ``cooldown`` number of epochs before resuming normal operation.

    Args:
        learning_rate (Variable|float|int): The initial learning rate. It can be set to python float or int number.
            If the type is Variable, it should be 1-D Tensor with shape [1], the data type can be 'float32' or 'float64'.
        mode (str, optional): ``'min'`` or ``'max'`` can be selected. Normally, it is ``'min'`` , which means that the 
            learning rate will reduce when ``loss`` stops descending. Specially, if it's set to ``'max'`` ,  the learning 
            rate will reduce when ``loss`` stops ascending. Default: ``'min'`` .
        decay_rate (float, optional): The ratio by which the learning rate will be reduced. ``new_lr = origin_lr * decay_rate`` .
            It should be less than 1.0. Default: 0.1.
        patience (int, optional): When ``loss`` doesn't improve for this number of epochs, the learning rate will be reduced.
            Default: 10.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False``.
        threshold (float, optional): ``threshold`` and ``threshold_mode`` will determine the minimum change of ``loss`` ,
            so that tiny changes of ``loss`` will be ignored. Default: 1e-4.
        threshold_mode (str, optional): ``'rel'`` or ``'abs'`` can be selected. In ``'rel'`` mode, the minimum change of ``loss``
            is ``last_loss * threshold`` , where ``last_loss`` is ``loss`` in last epoch. In ``'abs'`` mode, the minimum 
            change of ``loss`` is ``threshold`` . Default: ``'rel'`` .
        cooldown (int, optional): The number of epochs to wait before resuming normal operation. Default: 0.
        min_lr (float, optional): The lower bound of the learning rate after reduction. Default: 0.
        eps (float, optional): Minimal decay applied to lr. If the difference between new and old lr is smaller than eps, the update is
            ignored. Default: 1e-8.
        dtype (str, optional): The data type used to create the learning rate variable. The data type can be set as
            'float32', 'float64'. Default: 'float32'. 
    
    Returns:
        Reduced learning rate.

    Examples:
    
    .. code-block:: python

        import paddle.fluid as fluid
        import numpy as np

        with fluid.dygraph.guard():
            x = np.random.uniform(-1, 1, [10, 10]).astype("float32")
            linear = fluid.dygraph.Linear(10, 10)
            input = fluid.dygraph.to_variable(x)

            reduce_lr = fluid.dygraph.ReduceLROnPlateau(
                                    learning_rate = 1.0,
                                    decay_rate = 0.5,
                                    patience = 5,
                                    verbose = True, 
                                    cooldown = 3)
            adam = fluid.optimizer.Adam(
                learning_rate = reduce_lr,
                parameter_list = linear.parameters())

            for epoch in range(10):
                total_loss = 0
                for batch_id in range(5):
                    out = linear(input)
                    loss = fluid.layers.reduce_mean(out)
                    total_loss += loss
                    adam.minimize(loss)
                
                avg_loss = total_loss/5

                # adjust learning rate according to avg_loss
                reduce_lr.step(avg_loss)
                lr = adam.current_step_lr()
                print("current avg_loss is %s, current lr is %s" % (avg_loss.numpy()[0], lr))

    """

    def __init__(self,
                 learning_rate,
                 mode='min',
                 decay_rate=0.1,
                 patience=10,
                 verbose=False,
                 threshold=1e-4,
                 threshold_mode='rel',
                 cooldown=0,
                 min_lr=0,
                 eps=1e-8,
                 dtype='float32'):
        super(ReduceLROnPlateau, self).__init__(dtype=dtype)
        mode = mode.lower()
        if mode not in ['min', 'max']:
            raise ValueError('mode ' + mode + ' is unknown!')
        self.mode = mode

        if decay_rate >= 1.0:
            raise ValueError(
                'new_lr = origin_lr * decay_rate and decay_rate should be < 1.0.'
            )
        self.decay_rate = self.create_lr_var(decay_rate)

        threshold_mode = threshold_mode.lower()
        if threshold_mode not in ['rel', 'abs']:
            raise ValueError('threshold mode ' + threshold_mode +
                             ' is unknown!')
        self.threshold_mode = threshold_mode
        check_type(learning_rate, 'learning_rate', (float, int, Variable),
                   'ReduceLROnPlateau')
        if not isinstance(learning_rate, (float, int, Variable)):
            raise TypeError(
                "The type of 'learning_rate' in 'ReduceLROnPlateau' must be 'float, int, Variable', but received %s."
                % type(learning_rate))

        self.learning_rate = learning_rate
        self.verbose = verbose
        self.patience = patience
        self.threshold = threshold
        self.threshold_mode = threshold_mode
        self.cooldown = cooldown
        self.min_lr = self.create_lr_var(min_lr)
        self.eps = eps

        self.cooldown_counter = 0
        self.best_loss = None
        self.num_bad_epochs = 0
        self.epoch_num = 0

    # "cooldown_counter / best_loss / num_bad_epochs / epoch_num / learning_rate" will be stored.
    def _state_keys(self):
        self.keys = [
            'cooldown_counter', 'best_loss', 'num_bad_epochs', 'epoch_num',
            'learning_rate'
        ]

    def __call__(self):
        if not isinstance(self.learning_rate, Variable):
            self.learning_rate = self.create_lr_var(self.learning_rate)
        return self.learning_rate

    def step(self, loss):
        """
        It should be invoked on each epoch. Update the learning rate in optimizer according to ``loss`` .  
        The new learning rate will take effect on next call to ``optimizer.minimize`` .

        Args:
            loss (Variable): A ``Variable`` that will be monitored to determine whether the learning rate will reduce.
                If it stops descending for a ``patience`` number of epochs, the learning rate will reduce. It should
                be a 1-D Tensor with shape [1].
                Specially, if ``mode`` has been set to ``'max'`` ,  the learning rate will reduce when it stops ascending.
        Returns:
            None
        
        Examples:
            Please refer to the example of current LearningRateDecay.
        """

        # loss must be 1-D Tensor with shape [1]
        check_type(loss, 'loss', Variable, 'ReduceLROnPlateau.step')
        assert len(loss.shape) == 1 and loss.shape[0] == 1, "the loss.shape " \
            "should be (1L,), but the current loss.shape is {}. Maybe that "  \
            "you should call fluid.layers.mean to process it first.".format(loss.shape)

        self.epoch_num += 1
        if self.cooldown_counter > 0:
            self.cooldown_counter -= 1
        else:
            if self.best_loss is None or self._is_better(loss, self.best_loss):
                self.best_loss = loss
                self.num_bad_epochs = 0
            else:
                self.num_bad_epochs += 1

            if self.num_bad_epochs > self.patience:
                from .. import layers
                self.cooldown_counter = self.cooldown
                self.num_bad_epochs = 0
                new_lr = layers.elementwise_max(self.learning_rate *
                                                self.decay_rate, self.min_lr)
                if self.learning_rate - new_lr > self.eps:
                    if self.verbose:
                        old_lr = self.learning_rate.numpy()[0] if isinstance(
                            self.learning_rate,
                            Variable) else self.learning_rate
                        print('Epoch {}: reducing learning rate from {} to {}.'.
                              format(self.epoch_num, old_lr, new_lr.numpy()[0]))
                    self.learning_rate = new_lr

    def _is_better(self, current, best):
        if self.mode == 'min' and self.threshold_mode == 'rel':
            return current < best - best * self.threshold

        elif self.mode == 'min' and self.threshold_mode == 'abs':
            return current < best - self.threshold

        elif self.mode == 'max' and self.threshold_mode == 'rel':
            return current > best + best * self.threshold

        else:
            return current > best + self.threshold


class _LearningRateEpochDecay(LearningRateDecay):
    """
    :api_attr: imperative

    Base class of learning rate decay, which is updated each epoch.
    
    Define the common interface of a _LearningRateEpochDecay.
    User should not use this class directly,
    but need to use one of its implementations, and invoke the method `epoch()` each epoch.
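
    Because ``epoch()`` calls ``get_lr()`` and caches the result, a subclass only needs to
    implement ``get_lr()``. A minimal sketch of a hypothetical subclass (for illustration
    only, not part of the public API):

    .. code-block:: python

        class HalvingDecay(_LearningRateEpochDecay):
            # halves the base learning rate after every epoch
            def get_lr(self):
                return self.base_lr * (0.5 ** self.epoch_num)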
    """

    def __init__(self, learning_rate, dtype=None):
        if not isinstance(learning_rate, (float, int)):
            raise TypeError(
                "The type of 'learning_rate' must be 'float, int', but received %s."
                % type(learning_rate))
        if learning_rate < 0:
            raise ValueError("Invalid learning rate: {}".format(learning_rate))

        self.base_lr = float(learning_rate)

        self.epoch_num = -1
        self.dtype = dtype
        if dtype is None:
            self.dtype = "float32"
        self.learning_rate = self.create_lr_var(self.base_lr)

        self.epoch()

    # For those subclasses that overload _LearningRateEpochDecay, "self.epoch_num/learning_rate" will be stored by default.
    # You can change it for your subclass.
    def _state_keys(self):
        self.keys = ['epoch_num', 'learning_rate']

    def __call__(self):
        """ 
        Return last computed learning rate on current epoch.
        """
        if not isinstance(self.learning_rate, Variable):
            self.learning_rate = self.create_lr_var(self.learning_rate)
        return self.learning_rate

    def epoch(self, epoch=None):
        """
        Compute the learning_rate and update it when invoked.
        """
        if epoch is None:
            self.epoch_num += 1
        else:
            self.epoch_num = epoch

        self.learning_rate = self.get_lr()

    def get_lr(self):
        raise NotImplementedError


class StepDecay(_LearningRateEpochDecay):
    """
    :api_attr: imperative

    Decays the learning rate of ``optimizer`` by ``decay_rate`` every ``step_size`` number of epochs.

    The algorithm can be described as the code below. 

    .. code-block:: text

        learning_rate = 0.5
        step_size = 30
        decay_rate = 0.1

        learning_rate = 0.5     if epoch < 30
        learning_rate = 0.05    if 30 <= epoch < 60
        learning_rate = 0.005   if 60 <= epoch < 90
        ...

    Parameters:
        learning_rate (float|int): The initial learning rate. It can be set to python float or int number.
        step_size (int): Period of learning rate decay.
        decay_rate (float, optional): The ratio by which the learning rate will be reduced. ``new_lr = origin_lr * decay_rate`` .
            It should be less than 1.0. Default: 0.1.

    Returns:
        None.

    Examples:
        .. code-block:: python
            
            import paddle.fluid as fluid
            import numpy as np
            with fluid.dygraph.guard():
                x = np.random.uniform(-1, 1, [10, 10]).astype("float32")
                linear = fluid.dygraph.Linear(10, 10)
                input = fluid.dygraph.to_variable(x)
                scheduler = fluid.dygraph.StepDecay(0.5, step_size=3)
                adam = fluid.optimizer.Adam(learning_rate = scheduler, parameter_list = linear.parameters())

                for epoch in range(9):
                    for batch_id in range(5):
                        out = linear(input)
                        loss = fluid.layers.reduce_mean(out)
                        adam.minimize(loss)  
                    scheduler.epoch()

                    print("epoch:{}, current lr is {}" .format(epoch, adam.current_step_lr()))
                    # epoch:0, current lr is 0.5
                    # epoch:1, current lr is 0.5
                    # epoch:2, current lr is 0.5
                    # epoch:3, current lr is 0.05
                    # epoch:4, current lr is 0.05
                    # epoch:5, current lr is 0.05
                    # epoch:6, current lr is 0.005
                    # epoch:7, current lr is 0.005
                    # epoch:8, current lr is 0.005

    """

    def __init__(self, learning_rate, step_size, decay_rate=0.1):
        if not isinstance(step_size, int):
            raise TypeError(
                "The type of 'step_size' must be 'int', but received %s." %
                type(step_size))
        if decay_rate >= 1.0:
            raise ValueError('decay_rate should be < 1.0.')

        self.step_size = step_size
        self.decay_rate = decay_rate
        super(StepDecay, self).__init__(learning_rate)

    def get_lr(self):
        decay_rate = self.create_lr_var(self.decay_rate)
        i = self.epoch_num // self.step_size
        return self.base_lr * (decay_rate**i)


class MultiStepDecay(_LearningRateEpochDecay):
    """
    :api_attr: imperative

    Decays the learning rate of ``optimizer`` by ``decay_rate`` once ``epoch`` reaches one of the milestones.

    The algorithm can be described as the code below. 

    .. code-block:: text

        learning_rate = 0.5
        milestones = [30, 50]
        decay_rate = 0.1
        if epoch < 30:
            learning_rate = 0.5
        elif epoch < 50:
            learning_rate = 0.05
        else:
            learning_rate = 0.005

    Parameters:
        learning_rate (float|int): The initial learning rate. It can be set to python float or int number.
        milestones (tuple|list): List or tuple of milestone epoch boundaries. Must be increasing.
        decay_rate (float, optional): The ratio by which the learning rate will be reduced. ``new_lr = origin_lr * decay_rate`` .
            It should be less than 1.0. Default: 0.1.

    Returns:
        None.

    Examples:
        .. code-block:: python
            
            import paddle.fluid as fluid
            import numpy as np
            with fluid.dygraph.guard():
                x = np.random.uniform(-1, 1, [10, 10]).astype("float32")
                linear = fluid.dygraph.Linear(10, 10)
                input = fluid.dygraph.to_variable(x)
                scheduler = fluid.dygraph.MultiStepDecay(0.5, milestones=[3, 5])
                adam = fluid.optimizer.Adam(learning_rate = scheduler, parameter_list = linear.parameters())

                for epoch in range(6):
                    for batch_id in range(5):
                        out = linear(input)
                        loss = fluid.layers.reduce_mean(out)
                        adam.minimize(loss)
                    scheduler.epoch()

                    print("epoch:{}, current lr is {}" .format(epoch, adam.current_step_lr()))
                    # epoch:0, current lr is 0.5
                    # epoch:1, current lr is 0.5
                    # epoch:2, current lr is 0.5
                    # epoch:3, current lr is 0.05
                    # epoch:4, current lr is 0.05
                    # epoch:5, current lr is 0.005

    """

    def __init__(self, learning_rate, milestones, decay_rate=0.1):
        if not isinstance(milestones, (tuple, list)):
            raise TypeError(
                "The type of 'milestones' in 'MultiStepDecay' must be 'tuple, list', but received %s."
                % type(milestones))

        if not all([
                milestones[i] < milestones[i + 1]
                for i in range(len(milestones) - 1)
        ]):
            raise ValueError('The elements of milestones must be increasing')
        if decay_rate >= 1.0:
            raise ValueError('decay_rate should be < 1.0.')

        self.milestones = milestones
        self.decay_rate = decay_rate
        super(MultiStepDecay, self).__init__(learning_rate)

    def get_lr(self):
        decay_rate = self.create_lr_var(self.decay_rate)
        for i in range(len(self.milestones)):
            if self.epoch_num < self.milestones[i]:
                return self.base_lr * (decay_rate**i)

        return self.base_lr * (decay_rate**len(self.milestones))


class LambdaDecay(_LearningRateEpochDecay):
    """
    :api_attr: imperative

    Sets the learning rate of ``optimizer`` to the initial lr times a multiplicative factor, and this multiplicative
    factor is computed by the function ``lr_lambda`` . ``lr_lambda`` is a function which receives ``epoch`` .

    The algorithm can be described as the code below. 

    .. code-block:: text

        learning_rate = 0.5        # init learning_rate
        lr_lambda = lambda epoch: 0.95 ** epoch

        learning_rate = 0.5        # epoch 0
        learning_rate = 0.475      # epoch 1
        learning_rate = 0.45125    # epoch 2

    Parameters:
        learning_rate (float|int): The initial learning rate. It can be set to python float or int number.
        lr_lambda (function): A function which computes a multiplicative factor given an integer parameter ``epoch`` ,
            and the initial learning rate will be multiplied by this factor.
    
    Returns:
        None.

    Examples:
        .. code-block:: python
            
            import paddle.fluid as fluid
            import numpy as np
            with fluid.dygraph.guard():
                x = np.random.uniform(-1, 1, [10, 10]).astype("float32")
                linear = fluid.dygraph.Linear(10, 10)
                input = fluid.dygraph.to_variable(x)
                scheduler = fluid.dygraph.LambdaDecay(0.5, lr_lambda=lambda x: 0.95**x)
                adam = fluid.optimizer.Adam(learning_rate = scheduler, parameter_list = linear.parameters())

                for epoch in range(6):
                    for batch_id in range(5):
                        out = linear(input)
                        loss = fluid.layers.reduce_mean(out)
                        adam.minimize(loss)
                    scheduler.epoch()

                    print("epoch:%d, current lr is %f" .format(epoch, adam.current_step_lr()))
                    # epoch:0, current lr is 0.5
                    # epoch:1, current lr is 0.475
                    # epoch:2, current lr is 0.45125

    """

    def __init__(self, learning_rate, lr_lambda):
        if not callable(lr_lambda):
            raise TypeError(
                "The type of 'lr_lambda' in 'LambdaDecay' must be 'function', but received %s."
                % type(lr_lambda))

        self.lr_lambda = lr_lambda
        super(LambdaDecay, self).__init__(learning_rate)

    def get_lr(self):
        base_lr = self.create_lr_var(self.base_lr)

        return self.base_lr * self.lr_lambda(self.epoch_num)