Unverified commit 4a3a2d6b, authored by guguguzi, committed by GitHub

Add api MultiplicativeDecay (#38250)

* delete the modification of dygraph

* CI

* check CI

* modify the return value of get_lr
Parent c8fbd3cd
@@ -205,6 +205,13 @@ def lambda_lr(epoch_num, learning_rate, lr_lambda, verbose=False):
    return learning_rate * lr_lambda(epoch_num)
def multiplicative_lr(epoch_num, learning_rate, lr_lambda, verbose=False):
    latest_lr = learning_rate
    for i in range(epoch_num):
        latest_lr = latest_lr * lr_lambda(i + 1)
    return latest_lr
def piecewise_lr(epoch_num, boundaries, values, verbose=False):
    assert len(boundaries) + 1 == len(values)
    for i in range(len(boundaries)):
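For reference only (not part of the patch): the multiplicative_lr helper above simply accumulates the lambda factors, so with the constants used in the new test case (initial rate 0.5, constant factor 0.95) it would yield:

# Illustrative sketch, assuming the multiplicative_lr helper defined above is in scope.
expected = [multiplicative_lr(epoch, 0.5, lambda x: 0.95) for epoch in range(3)]
# expected == [0.5, 0.475, 0.45125], i.e. 0.5, 0.5*0.95, 0.5*0.95**2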
@@ -519,6 +526,10 @@ class TestLRScheduler(unittest.TestCase):
                "learning_rate": 0.5,
                "lr_lambda": lambda x: 0.95**x,
                "verbose": True
            }), (multiplicative_lr, paddle.optimizer.lr.MultiplicativeDecay, {
                "learning_rate": 0.5,
                "lr_lambda": lambda x: 0.95,
                "verbose": True
            }), (cosine_annealing_lr, paddle.optimizer.lr.CosineAnnealingDecay, {
                "learning_rate": 0.5,
                "T_max": 10,
......
@@ -17,7 +17,7 @@ import numpy
import warnings
from paddle import Tensor
__all__ = [  # noqa
    'LRScheduler',
    'NoamDecay',
    'PiecewiseDecay',
@@ -30,7 +30,8 @@ __all__ = [ #noqa
    'StepDecay',
    'LambdaDecay',
    'ReduceOnPlateau',
    'CosineAnnealingDecay',
    'MultiplicativeDecay'
]
@@ -55,9 +56,9 @@ class LRScheduler(object):
    Examples:
        Here is an example of a simple ``StepDecay`` implementation.
        .. code-block:: python
            import paddle
            from paddle.optimizer.lr import LRScheduler
@@ -99,7 +100,7 @@ class LRScheduler(object):
        self.step()
    def __call__(self):
        """
        Return latest computed learning rate on current epoch.
""" """
return self.last_lr return self.last_lr
...@@ -107,7 +108,7 @@ class LRScheduler(object): ...@@ -107,7 +108,7 @@ class LRScheduler(object):
def step(self, epoch=None): def step(self, epoch=None):
""" """
``step`` should be called after ``optimizer.step`` . It will update the learning rate in optimizer according to current ``epoch`` . ``step`` should be called after ``optimizer.step`` . It will update the learning rate in optimizer according to current ``epoch`` .
The new learning rate will take effect on next ``optimizer.step`` . The new learning rate will take effect on next ``optimizer.step`` .
Args: Args:
...@@ -191,7 +192,7 @@ class LRScheduler(object): ...@@ -191,7 +192,7 @@ class LRScheduler(object):
def get_lr(self): def get_lr(self):
""" """
For those subclass who overload ``LRScheduler`` (Base Class), User should have a custom implementation of ``get_lr()`` . For those subclass who overload ``LRScheduler`` (Base Class), User should have a custom implementation of ``get_lr()`` .
Otherwise, an ``NotImplementedError`` exception will be thrown. Otherwise, an ``NotImplementedError`` exception will be thrown.
...@@ -203,7 +204,7 @@ class LRScheduler(object): ...@@ -203,7 +204,7 @@ class LRScheduler(object):
class NoamDecay(LRScheduler): class NoamDecay(LRScheduler):
r""" r"""
Applies Noam Decay to the initial learning rate. Applies Noam Decay to the initial learning rate.
The algorithm can be described as following. The algorithm can be described as following.
...@@ -211,7 +212,7 @@ class NoamDecay(LRScheduler): ...@@ -211,7 +212,7 @@ class NoamDecay(LRScheduler):
new\_learning\_rate = learning\_rate * d_{model}^{-0.5} * min(epoch^{-0.5}, epoch * warmup\_steps^{-1.5}) new\_learning\_rate = learning\_rate * d_{model}^{-0.5} * min(epoch^{-0.5}, epoch * warmup\_steps^{-1.5})
Please reference `attention is all you need <https://arxiv.org/pdf/1706.03762.pdf>`_ Please reference `attention is all you need <https://arxiv.org/pdf/1706.03762.pdf>`_
Args: Args:
@@ -312,8 +313,8 @@ class PiecewiseDecay(LRScheduler):
            learning_rate = 0.1
    Args:
        boundaries(list|tuple): A list/tuple of steps numbers. The type of element in the list is python int.
        values(list|tuple): A list/tuple of learning rate values that will be picked during different epoch boundaries.
            The type of element in the list is python float.
        last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .
@@ -322,7 +323,7 @@ class PiecewiseDecay(LRScheduler):
        ``PiecewiseDecay`` instance to schedule learning rate.
    Examples:
        .. code-block:: python
            import paddle
@@ -388,7 +389,7 @@ class NaturalExpDecay(LRScheduler):
    r"""
    Applies natural exponential decay to the initial learning rate.
    The algorithm can be described as following:
    .. math::
@@ -405,7 +406,7 @@ class NaturalExpDecay(LRScheduler):
        ``NaturalExpDecay`` instance to schedule learning rate.
    Examples:
        .. code-block:: python
            import paddle
@@ -476,7 +477,7 @@ class InverseTimeDecay(LRScheduler):
    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        gamma (float, optional): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * gamma`` .
            It should be less than 1.0. Default: 0.1.
        last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .
@@ -485,7 +486,7 @@ class InverseTimeDecay(LRScheduler):
        ``InverseTimeDecay`` instance to schedule learning rate.
    Examples:
        .. code-block:: python
            import paddle
@@ -555,7 +556,7 @@ class PolynomialDecay(LRScheduler):
    .. math::
        decay\_steps & = decay\_steps * math.ceil(\frac{epoch}{decay\_steps})
        new\_learning\_rate & = (learning\_rate-end\_lr)*(1-\frac{epoch}{decay\_steps})^{power}+end\_lr
@@ -563,7 +564,7 @@ class PolynomialDecay(LRScheduler):
    .. math::
        epoch & = min(epoch, decay\_steps)
        new\_learning\_rate & = (learning\_rate-end\_lr)*(1-\frac{epoch}{decay\_steps})^{power}+end\_lr
@@ -573,7 +574,7 @@ class PolynomialDecay(LRScheduler):
        decay_steps(int): The decay step size. It determines the decay cycle. It must be a positive integer.
        end_lr(float, optional): The minimum final learning rate. Default: 0.0001.
        power(float, optional): Power of polynomial. Default: 1.0.
        cycle(bool, optional): Whether the learning rate rises again. If True, then the learning rate will rise when it decrease
            to ``end_lr`` . If False, the learning rate is monotone decreasing. Default: False.
        last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .
@@ -582,7 +583,7 @@ class PolynomialDecay(LRScheduler):
        ``PolynomialDecay`` instance to schedule learning rate.
    Examples:
        .. code-block:: python
            import paddle
@@ -671,21 +672,21 @@ class LinearWarmup(LRScheduler):
    Linear learning rate warm up strategy. Update the learning rate preliminarily before the normal learning rate scheduler.
    For more information, please refer to `Bag of Tricks for Image Classification with Convolutional Neural Networks <https://arxiv.org/abs/1812.01187>`_
    When epoch < warmup_steps, learning rate is updated as:
    .. math::
        lr = start\_lr + (end\_lr - start\_lr) * \frac{epoch}{warmup\_steps}
    where start_lr is the initial learning rate, and end_lr is the final learning rate;
    When epoch >= warmup_steps, learning rate is updated as:
    .. math::
        lr = learning_rate
    where ``learning_rate`` is float or any subclass of ``LRScheduler`` .
    Args:
@@ -700,7 +701,7 @@ class LinearWarmup(LRScheduler):
        ``LinearWarmup`` instance to schedule learning rate.
    Examples:
        .. code-block:: python
            import paddle
@@ -811,14 +812,14 @@ class ExponentialDecay(LRScheduler):
    Update learning rate by `gamma` each epoch.
    The algorithm can be described as following.
    .. math::
        new\_learning\_rate = last\_learning\_rate * gamma
    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        gamma (float): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * gamma`` .
            It should be less than 1.0.
        last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .
@@ -827,7 +828,7 @@ class ExponentialDecay(LRScheduler):
        ``ExponentialDecay`` instance to schedule learning rate.
    Examples:
        .. code-block:: python
            import paddle
@@ -889,7 +890,7 @@ class MultiStepDecay(LRScheduler):
    """
    Update the learning rate by ``gamma`` once ``epoch`` reaches one of the milestones.
    The algorithm can be described as the code below.
    .. code-block:: text
@@ -906,17 +907,17 @@ class MultiStepDecay(LRScheduler):
    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        milestones (tuple|list): List or tuple of each boundaries. Must be increasing.
        gamma (float, optional): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * gamma`` .
            It should be less than 1.0. Default: 0.1.
        last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .
    Returns:
        ``MultiStepDecay`` instance to schedule learning rate.
    Examples:
        .. code-block:: python
            import paddle
@@ -999,7 +1000,7 @@ class StepDecay(LRScheduler):
    """
    Update the learning rate of ``optimizer`` by ``gamma`` every ``step_size`` number of epoch.
    The algorithm can be described as the code below.
    .. code-block:: text
@@ -1015,7 +1016,7 @@ class StepDecay(LRScheduler):
    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        step_size (int): the interval to update. It must be a positive integer.
        gamma (float, optional): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * gamma`` .
            It should be less than 1.0. Default: 0.1.
        last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .
@@ -1025,7 +1026,7 @@ class StepDecay(LRScheduler):
    Examples:
        .. code-block:: python
            import paddle
@@ -1102,7 +1103,7 @@ class LambdaDecay(LRScheduler):
    """
    Sets the learning rate of ``optimizer`` by function ``lr_lambda`` . ``lr_lambda`` is a function which receives ``epoch`` .
    The algorithm can be described as the code below.
    .. code-block:: text
@@ -1118,12 +1119,12 @@ class LambdaDecay(LRScheduler):
        lr_lambda (function): A function which computes a factor by ``epoch`` , and then multiply the initial learning rate by this factor.
        last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .
    Returns:
        ``LambdaDecay`` instance to schedule learning rate.
    Examples:
        .. code-block:: python
            import paddle
@@ -1188,37 +1189,37 @@ class LambdaDecay(LRScheduler):
class ReduceOnPlateau(LRScheduler):
    """
    Reduce learning rate when ``metrics`` has stopped descending. Models often benefit from reducing the learning rate
    by 2 to 10 times once model performance has no longer improvement.
    The ``metrics`` is the one which has been pass into ``step`` , it must be 1-D Tensor with shape [1]. When ``metrics``
    stop descending for a ``patience`` number of epochs, the learning rate will be reduced to ``learning_rate * factor`` .
    (Specially, ``mode`` can also be set to ``'max`` , in this case, when ``metrics`` stop ascending for a ``patience``
    number of epochs, the learning rate will be reduced.)
    In addition, After each reduction, it will wait a ``cooldown`` number of epochs before resuming above operation.
    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        mode (str, optional): ``'min'`` or ``'max'`` can be selected. Normally, it is ``'min'`` , which means that the
            learning rate will reduce when ``loss`` stops descending. Specially, if it's set to ``'max'`` , the learning
            rate will reduce when ``loss`` stops ascending. Default: ``'min'`` .
        factor (float, optional): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * factor`` .
            It should be less than 1.0. Default: 0.1.
        patience (int, optional): When ``loss`` doesn't improve for this number of epochs, learning rate will be reduced.
            Default: 10.
        threshold (float, optional): ``threshold`` and ``threshold_mode`` will determine the minimum change of ``loss`` .
            This make tiny changes of ``loss`` will be ignored. Default: 1e-4.
        threshold_mode (str, optional): ``'rel'`` or ``'abs'`` can be selected. In ``'rel'`` mode, the minimum change of ``loss``
            is ``last_loss * threshold`` , where ``last_loss`` is ``loss`` in last epoch. In ``'abs'`` mode, the minimum
            change of ``loss`` is ``threshold`` . Default: ``'rel'`` .
        cooldown (int, optional): The number of epochs to wait before resuming normal operation. Default: 0.
        min_lr (float, optional): The lower bound of the learning rate after reduction. Default: 0.
        epsilon (float, optional): Minimal decay applied to lr. If the difference between new and old lr is smaller than epsilon,
            the update is ignored. Default: 1e-8.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False``.
    Returns:
        ``ReduceOnPlateau`` instance to schedule learning rate.
@@ -1331,18 +1332,18 @@ class ReduceOnPlateau(LRScheduler):
    def step(self, metrics, epoch=None):
        """
        step should be called after `optimizer.step()` . It will update the learning rate in optimizer according to ``metrics`` .
        The new learning rate will take effect on next epoch.
        Args:
            metrics (Tensor|numpy.ndarray|float): Which will be monitored to determine whether the learning rate will reduce.
                If it stop descending for a ``patience`` number of epochs, the learning rate will reduce. If it's 'Tensor' or
                'numpy.ndarray', its shape must be [1].
            epoch (int, None): specify current epoch. Default: None. Auto-increment from last_epoch=-1.
        Returns:
            None
        Examples:
            Please refer to the example of current LRScheduler.
        """
@@ -1354,8 +1355,9 @@ class ReduceOnPlateau(LRScheduler):
        # loss must be float, numpy.ndarray or 1-D Tensor with shape [1]
        if isinstance(metrics, (Tensor, numpy.ndarray)):
            assert len(metrics.shape) == 1 and metrics.shape[0] == 1, "the metrics.shape " \
                "should be (1L,), but the current metrics.shape is {}. Maybe that " \
                "you should call paddle.mean to process it first.".format(
                    metrics.shape)
        elif not isinstance(metrics,
                            (int, float, numpy.float32, numpy.float64)):
            raise TypeError(
@@ -1399,8 +1401,8 @@ class ReduceOnPlateau(LRScheduler):
class CosineAnnealingDecay(LRScheduler):
    r"""
    Set the learning rate using a cosine annealing schedule, where :math:`\eta_{max}` is set to
    the initial learning_rate. :math:`T_{cur}` is the number of epochs since the last restart in
    SGDR.
    The algorithm can be described as following.
@@ -1409,15 +1411,15 @@ class CosineAnnealingDecay(LRScheduler):
        \eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1
        + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right),
        & T_{cur} \neq (2k+1)T_{max};
        \eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min})
        \left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right),
        & T_{cur} = (2k+1)T_{max}.
    It has been proposed in `SGDR: Stochastic Gradient Descent with Warm Restarts <https://arxiv.org/abs/1608.03983>`_.
    Note that this only implements the cosine annealing part of SGDR, and not the restarts.
    Args:
        learning_rate (float): The initial learning rate, that is :math:`\eta_{max}` . It can be set to python float or int number.
        T_max (int): Maximum number of iterations. It is half of the decay cycle of learning rate. It must be a positive integer.
@@ -1429,7 +1431,7 @@ class CosineAnnealingDecay(LRScheduler):
        ``CosineAnnealingDecay`` instance to schedule learning rate.
    Examples:
        .. code-block:: python
            import paddle
@@ -1513,3 +1515,68 @@ class CosineAnnealingDecay(LRScheduler):
    def _get_closed_form_lr(self):
        return self.eta_min + (self.base_lr - self.eta_min) * (1 + math.cos(
            math.pi * self.last_epoch / self.T_max)) / 2
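As a side note (not part of the diff), the closed-form expression in _get_closed_form_lr above can be evaluated on its own; the constants below (base_lr=0.5, eta_min=0.0, T_max=10) are illustrative assumptions, not values from the patch:

import math

def closed_form_cosine_lr(last_epoch, base_lr=0.5, eta_min=0.0, T_max=10):
    # Mirrors the arithmetic of _get_closed_form_lr shown above (illustrative constants).
    return eta_min + (base_lr - eta_min) * (1 + math.cos(
        math.pi * last_epoch / T_max)) / 2

print(closed_form_cosine_lr(0))   # 0.5   -- start of the cosine cycle
print(closed_form_cosine_lr(5))   # ~0.25 -- halfway, cos(pi/2) ~= 0
print(closed_form_cosine_lr(10))  # 0.0   -- end of cycle, cos(pi) == -1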
class MultiplicativeDecay(LRScheduler):
    """
    Multiply the learning rate of ``optimizer`` by the factor given in function ``lr_lambda`` .
    The algorithm can be described as the code below.
    .. code-block:: text
        learning_rate = 0.5 # init learning_rate
        lr_lambda = lambda epoch: 0.95
        learning_rate = 0.5 # epoch 0,
        learning_rate = 0.475 # epoch 1, 0.5*0.95
        learning_rate = 0.45125 # epoch 2, 0.475*0.95
    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        lr_lambda (function): A function which computes a factor by ``epoch`` , and then multiply the last learning rate by this factor.
        last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .
    Returns:
        ``MultiplicativeDecay`` instance to schedule learning rate.
    Examples:
        .. code-block:: python
            import paddle
            import numpy as np
            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.MultiplicativeDecay(learning_rate=0.5, lr_lambda=lambda x:0.95, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch
    """

    def __init__(self, learning_rate, lr_lambda, last_epoch=-1, verbose=False):
        if not callable(lr_lambda):
            raise TypeError(
                "The type of 'lr_lambda' in 'MultiplicativeDecay' must be 'function', but received %s."
                % type(lr_lambda))
        self.lr_lambda = lr_lambda
        super(MultiplicativeDecay, self).__init__(learning_rate, last_epoch,
                                                  verbose)

    def get_lr(self):
        if self.last_epoch > 0:
            return self.last_lr * self.lr_lambda(self.last_epoch)
        else:
            return self.base_lr
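A minimal sanity sketch (not part of the patch) of how the recursive get_lr above reproduces the running product described in the docstring, assuming the scheduler behaves exactly as shown in this diff:

import paddle

# Hedged illustration; the constants mirror the docstring example (0.5 initial lr, factor 0.95).
scheduler = paddle.optimizer.lr.MultiplicativeDecay(learning_rate=0.5,
                                                    lr_lambda=lambda epoch: 0.95)
expected = 0.5
assert abs(scheduler() - expected) < 1e-8      # epoch 0: the initial learning rate
for epoch in range(1, 4):
    scheduler.step()
    expected *= 0.95                           # closed form: 0.5 * 0.95 ** epoch
    assert abs(scheduler() - expected) < 1e-8  # scheduler() returns the latest lr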