Unverified commit 4a3a2d6b authored by guguguzi, committed by GitHub

Add api MultiplicativeDecay (#38250)

* delete the modification of dygraph

* CI

* check CI

* modify the return value of get_lr
Parent c8fbd3cd
@@ -205,6 +205,13 @@ def lambda_lr(epoch_num, learning_rate, lr_lambda, verbose=False):
return learning_rate * lr_lambda(epoch_num)
def multiplicative_lr(epoch_num, learning_rate, lr_lambda, verbose=False):
latest_lr = learning_rate
for i in range(epoch_num):
latest_lr = latest_lr * lr_lambda(i + 1)
return latest_lr
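An illustrative aside, not part of the patch: with a constant ``lr_lambda`` the recurrence above has the closed form ``learning_rate * factor**epoch_num``, which is exactly what the ``MultiplicativeDecay`` test entry below exercises (the helper name here is hypothetical):
def check_multiplicative_lr_closed_form():
    # verify the recurrence against the closed form for a constant 0.95 factor
    for epoch_num in range(5):
        expected = 0.5 * 0.95**epoch_num
        got = multiplicative_lr(epoch_num, 0.5, lambda _: 0.95)
        assert abs(got - expected) < 1e-12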
def piecewise_lr(epoch_num, boundaries, values, verbose=False):
assert len(boundaries) + 1 == len(values)
for i in range(len(boundaries)):
@@ -519,6 +526,10 @@ class TestLRScheduler(unittest.TestCase):
"learning_rate": 0.5,
"lr_lambda": lambda x: 0.95**x,
"verbose": True
}), (multiplicative_lr, paddle.optimizer.lr.MultiplicativeDecay, {
"learning_rate": 0.5,
"lr_lambda": lambda x: 0.95,
"verbose": True
}), (cosine_annealing_lr, paddle.optimizer.lr.CosineAnnealingDecay, {
"learning_rate": 0.5,
"T_max": 10,
@@ -17,7 +17,7 @@ import numpy
import warnings
from paddle import Tensor
__all__ = [ # noqa
'LRScheduler',
'NoamDecay',
'PiecewiseDecay',
@@ -30,7 +30,8 @@ __all__ = [ #noqa
'StepDecay',
'LambdaDecay',
'ReduceOnPlateau',
'CosineAnnealingDecay',
'MultiplicativeDecay'
]
@@ -55,9 +56,9 @@ class LRScheduler(object):
Examples:
Here is an example of a simple ``StepDecay`` implementation.
.. code-block:: python
import paddle
from paddle.optimizer.lr import LRScheduler
@@ -99,7 +100,7 @@ class LRScheduler(object):
self.step()
def __call__(self):
"""
"""
Return the latest computed learning rate on the current epoch.
"""
return self.last_lr
@@ -107,7 +108,7 @@ class LRScheduler(object):
def step(self, epoch=None):
"""
``step`` should be called after ``optimizer.step`` . It will update the learning rate in optimizer according to current ``epoch`` .
The new learning rate will take effect on next ``optimizer.step`` .
Args:
@@ -191,7 +192,7 @@ class LRScheduler(object):
def get_lr(self):
"""
For subclasses that inherit from ``LRScheduler`` (the base class), users should provide a custom implementation of ``get_lr()`` .
Otherwise, a ``NotImplementedError`` exception will be thrown.
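A minimal sketch of such an overload (illustrative, not part of the patch; the subclass name is hypothetical):
import paddle
from paddle.optimizer.lr import LRScheduler
class ConstantFactorDecay(LRScheduler):
    # hypothetical subclass: halve the learning rate every epoch
    def get_lr(self):
        return self.base_lr * 0.5**self.last_epoch
scheduler = ConstantFactorDecay(learning_rate=0.1)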
@@ -203,7 +204,7 @@ class NoamDecay(LRScheduler):
class NoamDecay(LRScheduler):
r"""
Applies Noam Decay to the initial learning rate.
The algorithm can be described as follows.
@@ -211,7 +212,7 @@ class NoamDecay(LRScheduler):
new\_learning\_rate = learning\_rate * d_{model}^{-0.5} * min(epoch^{-0.5}, epoch * warmup\_steps^{-1.5})
Please refer to `attention is all you need <https://arxiv.org/pdf/1706.03762.pdf>`_
Args:
@@ -312,8 +313,8 @@ class PiecewiseDecay(LRScheduler):
learning_rate = 0.1
Args:
boundaries(list|tuple): A list/tuple of step numbers. The type of element in the list is python int.
values(list|tuple): A list/tuple of learning rate values that will be picked during different epoch boundaries.
The type of element in the list is python float.
last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .
@@ -322,7 +323,7 @@ class PiecewiseDecay(LRScheduler):
``PiecewiseDecay`` instance to schedule learning rate.
Examples:
.. code-block:: python
import paddle
@@ -388,7 +389,7 @@ class NaturalExpDecay(LRScheduler):
r"""
Applies natural exponential decay to the initial learning rate.
The algorithm can be described as follows:
.. math::
@@ -405,7 +406,7 @@ class NaturalExpDecay(LRScheduler):
``NaturalExpDecay`` instance to schedule learning rate.
Examples:
.. code-block:: python
import paddle
@@ -476,7 +477,7 @@ class InverseTimeDecay(LRScheduler):
Args:
learning_rate (float): The initial learning rate. It is a python float number.
gamma (float, optional): The ratio by which the learning rate will be reduced. ``new_lr = origin_lr * gamma`` .
It should be less than 1.0. Default: 0.1.
last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .
@@ -485,7 +486,7 @@ class InverseTimeDecay(LRScheduler):
``InverseTimeDecay`` instance to schedule learning rate.
Examples:
.. code-block:: python
import paddle
@@ -555,7 +556,7 @@ class PolynomialDecay(LRScheduler):
.. math::
decay\_steps & = decay\_steps * math.ceil(\frac{epoch}{decay\_steps})
new\_learning\_rate & = (learning\_rate-end\_lr)*(1-\frac{epoch}{decay\_steps})^{power}+end\_lr
@@ -563,7 +564,7 @@ class PolynomialDecay(LRScheduler):
.. math::
epoch & = min(epoch, decay\_steps)
new\_learning\_rate & = (learning\_rate-end\_lr)*(1-\frac{epoch}{decay\_steps})^{power}+end\_lr
@@ -573,7 +574,7 @@ class PolynomialDecay(LRScheduler):
decay_steps(int): The decay step size. It determines the decay cycle. It must be a positive integer.
end_lr(float, optional): The minimum final learning rate. Default: 0.0001.
power(float, optional): Power of polynomial. Default: 1.0.
cycle(bool, optional): Whether the learning rate rises again. If True, then the learning rate will rise when it decreases
to ``end_lr`` . If False, the learning rate is monotone decreasing. Default: False.
last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .
@@ -582,7 +583,7 @@ class PolynomialDecay(LRScheduler):
``PolynomialDecay`` instance to schedule learning rate.
Examples:
.. code-block:: python
import paddle
@@ -671,21 +672,21 @@ class LinearWarmup(LRScheduler):
Linear learning rate warm up strategy. Update the learning rate preliminarily before the normal learning rate scheduler.
For more information, please refer to `Bag of Tricks for Image Classification with Convolutional Neural Networks <https://arxiv.org/abs/1812.01187>`_
When epoch < warmup_steps, learning rate is updated as:
.. math::
lr = start\_lr + (end\_lr - start\_lr) * \frac{epoch}{warmup\_steps}
where start_lr is the initial learning rate, and end_lr is the final learning rate;
When epoch >= warmup_steps, learning rate is updated as:
.. math::
lr = learning_rate
where ``learning_rate`` is float or any subclass of ``LRScheduler`` .
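As a brief sketch of the warm-up hand-off described above (illustrative, not part of the patch), ``LinearWarmup`` can wrap either a float or another scheduler such as ``PiecewiseDecay``:
import paddle
# warm up linearly from 0.0 to 0.5 over 20 steps, then follow the
# wrapped PiecewiseDecay schedule
inner = paddle.optimizer.lr.PiecewiseDecay(boundaries=[30, 60],
                                           values=[0.5, 0.05, 0.005])
scheduler = paddle.optimizer.lr.LinearWarmup(
    learning_rate=inner, warmup_steps=20, start_lr=0.0, end_lr=0.5)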
Args:
@@ -700,7 +701,7 @@ class LinearWarmup(LRScheduler):
``LinearWarmup`` instance to schedule learning rate.
Examples:
.. code-block:: python
import paddle
@@ -811,14 +812,14 @@ class ExponentialDecay(LRScheduler):
Update learning rate by `gamma` each epoch.
The algorithm can be described as follows.
.. math::
new\_learning\_rate = last\_learning\_rate * gamma
Args:
learning_rate (float): The initial learning rate. It is a python float number.
gamma (float): The ratio by which the learning rate will be reduced. ``new_lr = origin_lr * gamma`` .
It should be less than 1.0.
last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .
@@ -827,7 +828,7 @@ class ExponentialDecay(LRScheduler):
``ExponentialDecay`` instance to schedule learning rate.
Examples:
.. code-block:: python
import paddle
@@ -889,7 +890,7 @@ class MultiStepDecay(LRScheduler):
"""
Update the learning rate by ``gamma`` once ``epoch`` reaches one of the milestones.
The algorithm can be described as the code below.
.. code-block:: text
@@ -906,17 +907,17 @@ class MultiStepDecay(LRScheduler):
Args:
learning_rate (float): The initial learning rate. It is a python float number.
milestones (tuple|list): List or tuple of boundaries. Must be increasing.
gamma (float, optional): The ratio by which the learning rate will be reduced. ``new_lr = origin_lr * gamma`` .
It should be less than 1.0. Default: 0.1.
last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .
Returns:
``MultiStepDecay`` instance to schedule learning rate.
Examples:
.. code-block:: python
import paddle
@@ -999,7 +1000,7 @@ class StepDecay(LRScheduler):
"""
Update the learning rate of ``optimizer`` by ``gamma`` every ``step_size`` number of epochs.
The algorithm can be described as the code below.
.. code-block:: text
@@ -1015,7 +1016,7 @@ class StepDecay(LRScheduler):
Args:
learning_rate (float): The initial learning rate. It is a python float number.
step_size (int): the interval to update. It must be a positive integer.
gamma (float, optional): The ratio by which the learning rate will be reduced. ``new_lr = origin_lr * gamma`` .
It should be less than 1.0. Default: 0.1.
last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .
@@ -1025,7 +1026,7 @@ class StepDecay(LRScheduler):
Examples:
.. code-block:: python
import paddle
@@ -1102,7 +1103,7 @@ class LambdaDecay(LRScheduler):
"""
Sets the learning rate of ``optimizer`` by function ``lr_lambda`` . ``lr_lambda`` is a function which receives ``epoch`` .
The algorithm can be described as the code below.
.. code-block:: text
@@ -1118,12 +1119,12 @@ class LambdaDecay(LRScheduler):
lr_lambda (function): A function which computes a factor by ``epoch`` , and then multiplies the initial learning rate by this factor.
last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .
Returns:
``LambdaDecay`` instance to schedule learning rate.
Examples:
.. code-block:: python
import paddle
@@ -1188,37 +1189,37 @@ class LambdaDecay(LRScheduler):
class ReduceOnPlateau(LRScheduler):
"""
Reduce learning rate when ``metrics`` has stopped descending. Models often benefit from reducing the learning rate
by 2 to 10 times once model performance no longer improves.
The ``metrics`` is the one which has been passed into ``step`` , it must be a 1-D Tensor with shape [1]. When ``metrics``
stops descending for a ``patience`` number of epochs, the learning rate will be reduced to ``learning_rate * factor`` .
(Specifically, ``mode`` can also be set to ``'max'`` , in this case, when ``metrics`` stops ascending for a ``patience``
number of epochs, the learning rate will be reduced.)
In addition, after each reduction, it will wait a ``cooldown`` number of epochs before resuming the above operation.
Args:
learning_rate (float): The initial learning rate. It is a python float number.
mode (str, optional): ``'min'`` or ``'max'`` can be selected. Normally, it is ``'min'`` , which means that the
learning rate will reduce when ``loss`` stops descending. Specifically, if it's set to ``'max'`` , the learning
rate will reduce when ``loss`` stops ascending. Default: ``'min'`` .
factor (float, optional): The ratio by which the learning rate will be reduced. ``new_lr = origin_lr * factor`` .
It should be less than 1.0. Default: 0.1.
patience (int, optional): When ``loss`` doesn't improve for this number of epochs, the learning rate will be reduced.
Default: 10.
threshold (float, optional): ``threshold`` and ``threshold_mode`` will determine the minimum change of ``loss`` .
Changes of ``loss`` smaller than this threshold will be ignored. Default: 1e-4.
threshold_mode (str, optional): ``'rel'`` or ``'abs'`` can be selected. In ``'rel'`` mode, the minimum change of ``loss``
is ``last_loss * threshold`` , where ``last_loss`` is ``loss`` in last epoch. In ``'abs'`` mode, the minimum
change of ``loss`` is ``threshold`` . Default: ``'rel'`` .
cooldown (int, optional): The number of epochs to wait before resuming normal operation. Default: 0.
min_lr (float, optional): The lower bound of the learning rate after reduction. Default: 0.
epsilon (float, optional): Minimal decay applied to lr. If the difference between new and old lr is smaller than epsilon,
the update is ignored. Default: 1e-8.
verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False``.
Returns:
``ReduceOnPlateau`` instance to schedule learning rate.
@@ -1331,18 +1332,18 @@ class ReduceOnPlateau(LRScheduler):
def step(self, metrics, epoch=None):
"""
step should be called after `optimizer.step()` . It will update the learning rate in optimizer according to ``metrics`` .
The new learning rate will take effect on next epoch.
Args:
metrics (Tensor|numpy.ndarray|float): Which will be monitored to determine whether the learning rate will reduce.
If it stops descending for a ``patience`` number of epochs, the learning rate will be reduced. If it's 'Tensor' or
'numpy.ndarray', its shape must be [1].
epoch (int, None): specify current epoch. Default: None. Auto-increment from last_epoch=-1.
Returns:
None
Examples:
Please refer to the example of current LRScheduler.
"""
@@ -1354,8 +1355,9 @@ class ReduceOnPlateau(LRScheduler):
# loss must be float, numpy.ndarray or 1-D Tensor with shape [1]
if isinstance(metrics, (Tensor, numpy.ndarray)):
assert len(metrics.shape) == 1 and metrics.shape[0] == 1, "the metrics.shape " \
"should be (1L,), but the current metrics.shape is {}. Maybe that " \
"you should call paddle.mean to process it first.".format(metrics.shape)
"should be (1L,), but the current metrics.shape is {}. Maybe that " \
"you should call paddle.mean to process it first.".format(
metrics.shape)
elif not isinstance(metrics,
(int, float, numpy.float32, numpy.float64)):
raise TypeError(
@@ -1399,8 +1401,8 @@ class ReduceOnPlateau(LRScheduler):
class CosineAnnealingDecay(LRScheduler):
r"""
Set the learning rate using a cosine annealing schedule, where :math:`\eta_{max}` is set to
the initial learning_rate. :math:`T_{cur}` is the number of epochs since the last restart in
SGDR.
The algorithm can be described as follows.
@@ -1409,15 +1411,15 @@ class CosineAnnealingDecay(LRScheduler):
\eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1
+ \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right),
& T_{cur} \neq (2k+1)T_{max};
\eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min})
\left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right),
& T_{cur} = (2k+1)T_{max}.
It has been proposed in `SGDR: Stochastic Gradient Descent with Warm Restarts <https://arxiv.org/abs/1608.03983>`_.
Note that this only implements the cosine annealing part of SGDR, and not the restarts.
Args:
learning_rate (float): The initial learning rate, that is :math:`\eta_{max}` . It can be set to python float or int number.
T_max (int): Maximum number of iterations. It is half of the decay cycle of learning rate. It must be a positive integer.
@@ -1429,7 +1431,7 @@ class CosineAnnealingDecay(LRScheduler):
``CosineAnnealingDecay`` instance to schedule learning rate.
Examples:
.. code-block:: python
import paddle
@@ -1513,3 +1515,68 @@ class CosineAnnealingDecay(LRScheduler):
def _get_closed_form_lr(self):
return self.eta_min + (self.base_lr - self.eta_min) * (1 + math.cos(
math.pi * self.last_epoch / self.T_max)) / 2
class MultiplicativeDecay(LRScheduler):
"""
Multiply the learning rate of ``optimizer`` by the factor given in function ``lr_lambda`` .
The algorithm can be described as the code below.
.. code-block:: text
learning_rate = 0.5 # init learning_rate
lr_lambda = lambda epoch: 0.95
learning_rate = 0.5 # epoch 0,
learning_rate = 0.475 # epoch 1, 0.5*0.95
learning_rate = 0.45125 # epoch 2, 0.475*0.95
Args:
learning_rate (float): The initial learning rate. It is a python float number.
lr_lambda (function): A function which computes a factor by ``epoch`` , and then multiplies the last learning rate by this factor.
last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .
Returns:
``MultiplicativeDecay`` instance to schedule learning rate.
Examples:
.. code-block:: python
import paddle
import numpy as np
# train on default dynamic graph mode
linear = paddle.nn.Linear(10, 10)
scheduler = paddle.optimizer.lr.MultiplicativeDecay(learning_rate=0.5, lr_lambda=lambda x: 0.95, verbose=True)
sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
for epoch in range(20):
for batch_id in range(5):
x = paddle.uniform([10, 10])
out = linear(x)
loss = paddle.mean(out)
loss.backward()
sgd.step()
sgd.clear_gradients()
scheduler.step() # If you update learning rate each step
# scheduler.step() # If you update learning rate each epoch
"""
def __init__(self, learning_rate, lr_lambda, last_epoch=-1, verbose=False):
if not callable(lr_lambda):
raise TypeError(
"The type of 'lr_lambda' in 'MultiplicativeDecay' must be 'function', but received %s."
% type(lr_lambda))
self.lr_lambda = lr_lambda
super(MultiplicativeDecay, self).__init__(learning_rate, last_epoch,
verbose)
def get_lr(self):
if self.last_epoch > 0:
return self.last_lr * self.lr_lambda(self.last_epoch)
else:
return self.base_lr
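To make the contrast with ``LambdaDecay`` concrete (an illustrative sketch, not part of the patch): ``LambdaDecay`` multiplies the initial learning rate by ``lr_lambda(epoch)``, while ``MultiplicativeDecay`` multiplies the last learning rate, so the same constant factor compounds only under the latter:
import paddle
mult = paddle.optimizer.lr.MultiplicativeDecay(learning_rate=0.5,
                                               lr_lambda=lambda x: 0.95)
lamb = paddle.optimizer.lr.LambdaDecay(learning_rate=0.5,
                                       lr_lambda=lambda x: 0.95)
for epoch in range(3):
    print(epoch, mult.last_lr, lamb.last_lr)
    mult.step()
    lamb.step()
# MultiplicativeDecay: 0.5, 0.475, 0.45125  (factor applied to the last lr)
# LambdaDecay:         0.475, 0.475, 0.475  (factor applied to the initial lr)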