Unverified commit d171f373, authored by Zhou Wei, committed by GitHub

[CHERRY-PICK 1.8] Add base class LearningRateEpochDecay, and MultiStepDecay, StepDecay (#25277)

* CHERRY-PICK 1.8: add base class LearningRateEpochDecay, and APIs MultiStepDecay and StepDecay, test=release/1.8

* fix unittest to add coverage, test=develop
Parent 43facfd3
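Before the diff itself, a minimal usage sketch of the two new schedulers, lifted from the docstring examples added below (it assumes the fluid 1.8 dygraph API): the scheduler is passed to the optimizer as its learning_rate, and scheduler.epoch() is called once per epoch.

# Usage sketch (from the StepDecay docstring example added in this commit).
import numpy as np
import paddle.fluid as fluid

with fluid.dygraph.guard():
    x = np.random.uniform(-1, 1, [10, 10]).astype("float32")
    linear = fluid.dygraph.Linear(10, 10)
    input = fluid.dygraph.to_variable(x)

    # Multiply the learning rate by 0.1 every 3 epochs; MultiStepDecay(0.5, milestones=[3, 5])
    # would instead decay at the listed epoch boundaries.
    scheduler = fluid.dygraph.StepDecay(0.5, step_size=3)
    adam = fluid.optimizer.Adam(
        learning_rate=scheduler, parameter_list=linear.parameters())

    for epoch in range(9):
        for batch_id in range(5):
            out = linear(input)
            loss = fluid.layers.reduce_mean(out)
            adam.minimize(loss)
        # Advance the schedule once per epoch (lr: 0.5 for epochs 0-2, 0.05 for 3-5, 0.005 for 6-8).
        scheduler.epoch()
        print("epoch:{}, current lr is {}".format(epoch, adam.current_step_lr()))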
@@ -23,7 +23,7 @@ from ..data_feeder import check_type
__all__ = [
    'NoamDecay', 'PiecewiseDecay', 'NaturalExpDecay', 'ExponentialDecay',
    'InverseTimeDecay', 'PolynomialDecay', 'CosineDecay', 'LinearLrWarmup',
-   'ReduceLROnPlateau'
+   'ReduceLROnPlateau', 'StepDecay', 'MultiStepDecay'
]
@@ -72,6 +72,8 @@ class LearningRateDecay(object):
class PiecewiseDecay(LearningRateDecay):
    """
    :api_attr: imperative

    Piecewise decay scheduler.
    The algorithm can be described as the code below.
@@ -131,6 +133,8 @@ class PiecewiseDecay(LearningRateDecay):
class NaturalExpDecay(LearningRateDecay):
    """
    :api_attr: imperative

    Applies natural exponential decay to the initial learning rate.
    The algorithm can be described as following.
@@ -183,7 +187,6 @@ class NaturalExpDecay(LearningRateDecay):
                staircase=True),
            parameter_list=emb.parameters())
    """
    def __init__(self,
@@ -213,6 +216,8 @@ class NaturalExpDecay(LearningRateDecay):
class ExponentialDecay(LearningRateDecay):
    """
    :api_attr: imperative

    Applies exponential decay to the learning rate.
    The algorithm can be described as following.
@@ -293,6 +298,8 @@ class ExponentialDecay(LearningRateDecay):
class InverseTimeDecay(LearningRateDecay):
    """
    :api_attr: imperative

    Applies inverse time decay to the initial learning rate.
    The algorithm can be described as following.
@@ -369,6 +376,8 @@ class InverseTimeDecay(LearningRateDecay):
class PolynomialDecay(LearningRateDecay):
    """
    :api_attr: imperative

    Applies polynomial decay to the initial learning rate.
    The algorithm can be described as following.
@@ -461,6 +470,8 @@ class PolynomialDecay(LearningRateDecay):
class CosineDecay(LearningRateDecay):
    """
    :api_attr: imperative

    Applies cosine decay to the learning rate.
    The algorithm can be described as following.
@@ -517,6 +528,8 @@ class CosineDecay(LearningRateDecay):
class NoamDecay(LearningRateDecay):
    """
    :api_attr: imperative

    Applies Noam decay to the initial learning rate.
    The algorithm can be described as following.
@@ -582,6 +595,8 @@ class NoamDecay(LearningRateDecay):
class LinearLrWarmup(LearningRateDecay):
    """
    :api_attr: imperative

    This operator use the linear learning rate warm up strategy to adjust the learning rate preliminarily before the normal learning rate scheduling.
    For more information, please refer to `Bag of Tricks for Image Classification with Convolutional Neural Networks <https://arxiv.org/abs/1812.01187>`_
@@ -670,6 +685,8 @@ class LinearLrWarmup(LearningRateDecay):
class ReduceLROnPlateau(LearningRateDecay):
    """
    :api_attr: imperative

    Reduce learning rate when ``loss`` has stopped descending. Models often benefit from reducing the learning rate
    by 2 to 10 times once model performance has no longer improvement.
@@ -774,7 +791,6 @@ class ReduceLROnPlateau(LearningRateDecay):
                raise ValueError('threshold mode ' + threshold_mode +
                                 ' is unknown!')
        self.threshold_mode = threshold_mode
        check_type(learning_rate, 'learning_rate', (float, int, Variable),
                   'ReduceLROnPlateau')
        if isinstance(learning_rate, (float, int)):
@@ -856,3 +872,217 @@ class ReduceLROnPlateau(LearningRateDecay):
        else:
            return current > best + self.threshold
class _LearningRateEpochDecay(LearningRateDecay):
"""
:api_attr: imperative
Base class of learning rate decay, which is updated each epoch.
Define the common interface of an _LearningRateEpochDecay.
User should not use this class directly,
but need to use one of it's implementation. And invoke method: `epoch()` each epoch.
"""
def __init__(self, learning_rate, dtype=None):
if not isinstance(learning_rate, (float, int)):
raise TypeError(
"The type of 'learning_rate' must be 'float, int', but received %s."
% type(learning_rate))
        if learning_rate >= 1.0:
            raise ValueError("The initial learning rate must be less than 1.0.")
self.base_lr = float(learning_rate)
self.epoch_num = -1
if dtype is None:
self.dtype = "float32"
self.learning_rate = self.create_lr_var(self.base_lr)
self.epoch()
def __call__(self):
"""
Return last computed learning rate on current epoch.
"""
return self.learning_rate
def epoch(self, epoch=None):
"""
compueted learning_rate and update it when invoked.
"""
if epoch is None:
self.epoch_num += 1
else:
self.epoch_num = epoch
self.learning_rate = self.get_lr()
if isinstance(self.learning_rate, float):
self.learning_rate = self.create_lr_var(self.learning_rate)
def get_lr(self):
raise NotImplementedError
class StepDecay(_LearningRateEpochDecay):
"""
:api_attr: imperative
    Decays the learning rate of ``optimizer`` by ``decay_rate`` every ``step_size`` epochs.
The algorithm can be described as the code below.
.. code-block:: text
learning_rate = 0.5
step_size = 30
decay_rate = 0.1
learning_rate = 0.5 if epoch < 30
learning_rate = 0.05 if 30 <= epoch < 60
learning_rate = 0.005 if 60 <= epoch < 90
...
Parameters:
learning_rate (float|int): The initial learning rate. It can be set to python float or int number.
        step_size (int): Period of learning rate decay, in epochs.
        decay_rate (float, optional): The ratio by which the learning rate will be reduced: ``new_lr = origin_lr * decay_rate``.
            It should be less than 1.0. Default: 0.1.
Returns:
None.
Examples:
.. code-block:: python
import paddle.fluid as fluid
import numpy as np
with fluid.dygraph.guard():
x = np.random.uniform(-1, 1, [10, 10]).astype("float32")
linear = fluid.dygraph.Linear(10, 10)
input = fluid.dygraph.to_variable(x)
scheduler = fluid.dygraph.StepDecay(0.5, step_size=3)
adam = fluid.optimizer.Adam(learning_rate = scheduler, parameter_list = linear.parameters())
for epoch in range(9):
for batch_id in range(5):
out = linear(input)
loss = fluid.layers.reduce_mean(out)
adam.minimize(loss)
scheduler.epoch()
print("epoch:{}, current lr is {}" .format(epoch, adam.current_step_lr()))
# epoch:0, current lr is 0.5
# epoch:1, current lr is 0.5
# epoch:2, current lr is 0.5
# epoch:3, current lr is 0.05
# epoch:4, current lr is 0.05
# epoch:5, current lr is 0.05
# epoch:6, current lr is 0.005
# epoch:7, current lr is 0.005
# epoch:8, current lr is 0.005
"""
def __init__(self, learning_rate, step_size, decay_rate=0.1):
if not isinstance(step_size, int):
raise TypeError(
"The type of 'step_size' must be 'int', but received %s." %
type(step_size))
if decay_rate >= 1.0:
raise ValueError('decay_rate should be < 1.0.')
self.step_size = step_size
self.decay_rate = decay_rate
super(StepDecay, self).__init__(learning_rate)
def get_lr(self):
decay_rate = self.create_lr_var(self.decay_rate)
i = self.epoch_num // self.step_size
return self.base_lr * (decay_rate**i)
class MultiStepDecay(_LearningRateEpochDecay):
"""
:api_attr: imperative
Decays the learning rate of ``optimizer`` by ``decay_rate`` once ``epoch`` reaches one of the milestones.
The algorithm can be described as the code below.
.. code-block:: text
learning_rate = 0.5
milestones = [30, 50]
decay_rate = 0.1
if epoch < 30:
learning_rate = 0.5
elif epoch < 50:
learning_rate = 0.05
else:
learning_rate = 0.005
Parameters:
        learning_rate (float|int): The initial learning rate. It can be set to a python float or int number.
        milestones (tuple|list): List or tuple of epoch boundaries. Must be increasing.
        decay_rate (float, optional): The ratio by which the learning rate will be reduced: ``new_lr = origin_lr * decay_rate``.
            It should be less than 1.0. Default: 0.1.
Returns:
None.
Examples:
.. code-block:: python
import paddle.fluid as fluid
import numpy as np
with fluid.dygraph.guard():
x = np.random.uniform(-1, 1, [10, 10]).astype("float32")
linear = fluid.dygraph.Linear(10, 10)
input = fluid.dygraph.to_variable(x)
scheduler = fluid.dygraph.MultiStepDecay(0.5, milestones=[3, 5])
adam = fluid.optimizer.Adam(learning_rate = scheduler, parameter_list = linear.parameters())
for epoch in range(6):
for batch_id in range(5):
out = linear(input)
loss = fluid.layers.reduce_mean(out)
adam.minimize(loss)
scheduler.epoch()
print("epoch:{}, current lr is {}" .format(epoch, adam.current_step_lr()))
# epoch:0, current lr is 0.5
# epoch:1, current lr is 0.5
# epoch:2, current lr is 0.5
# epoch:3, current lr is 0.05
# epoch:4, current lr is 0.05
# epoch:5, current lr is 0.005
"""
def __init__(self, learning_rate, milestones, decay_rate=0.1):
if not isinstance(milestones, (tuple, list)):
raise TypeError(
"The type of 'milestones' in 'MultiStepDecay' must be 'tuple, list', but received %s."
% type(milestones))
if not all([
milestones[i] < milestones[i + 1]
for i in range(len(milestones) - 1)
]):
            raise ValueError('The elements of milestones must be increasing.')
if decay_rate >= 1.0:
raise ValueError('decay_rate should be < 1.0.')
self.milestones = milestones
self.decay_rate = decay_rate
super(MultiStepDecay, self).__init__(learning_rate)
def get_lr(self):
decay_rate = self.create_lr_var(self.decay_rate)
for i in range(len(self.milestones)):
if self.epoch_num < self.milestones[i]:
return self.base_lr * (decay_rate**i)
return self.base_lr * (decay_rate**len(self.milestones))
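Both schedulers above reduce to a single get_lr() override on top of _LearningRateEpochDecay, which tracks base_lr and epoch_num and refreshes the lr variable on each epoch() call. As a hypothetical illustration (not part of this commit, and assuming the module path of the file being modified), a custom epoch-level schedule can be built the same way:

# Hypothetical sketch: a custom per-epoch exponential schedule built on the new base class.
import paddle.fluid as fluid
from paddle.fluid.dygraph.learning_rate_scheduler import _LearningRateEpochDecay


class ExponentialEpochDecay(_LearningRateEpochDecay):
    """Multiply the base learning rate by `gamma` once per epoch (illustrative only)."""

    def __init__(self, learning_rate, gamma=0.5):
        # Set attributes needed by get_lr() before the base __init__,
        # which already calls epoch() once to compute the initial value.
        self.gamma = gamma
        super(ExponentialEpochDecay, self).__init__(learning_rate)

    def get_lr(self):
        # epoch_num and base_lr are maintained by _LearningRateEpochDecay.
        return self.base_lr * (self.gamma**self.epoch_num)


with fluid.dygraph.guard():
    scheduler = ExponentialEpochDecay(0.5, gamma=0.5)
    for epoch in range(3):
        print(scheduler().numpy()[0])  # 0.5, 0.25, 0.125
        scheduler.epoch()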
@@ -498,7 +498,6 @@ def linear_lr_warmup(learning_rate, warmup_steps, start_lr, end_lr):
    Returns:
        Variable: Warm-up learning rate with the same data type as learning_rate.
    Examples:
        .. code-block:: python
...
@@ -98,8 +98,26 @@ def noam_decay(global_step, d_model, warmup_steps, learning_rate=1.0):
    return decayed_lr
-class TestNoamLearningRateDecayDygraphMode(unittest.TestCase):
-    def test_dygraph_mode(self):
+def linear_lr_warmup(global_step, warmup_steps, start_lr, end_lr):
+    linear_step = end_lr - start_lr
    decayed_lr = start_lr + linear_step * (global_step / warmup_steps)
    return decayed_lr
def multi_step_decay(global_step, learning_rate, milestones, decay_rate=0.1):
for i in range(len(milestones)):
if global_step < milestones[i]:
return learning_rate * math.pow(decay_rate, i)
return learning_rate * math.pow(decay_rate, len(milestones))
def step_decay(global_step, learning_rate, step_size, decay_rate=0.1):
return learning_rate * math.pow(decay_rate, global_step // step_size)
class TestLearningRateDecayDygraph(unittest.TestCase):
def test_NoamDecay(self):
        with fluid.dygraph.guard():
            d_model = 0.01
            warmup_steps = 200
@@ -117,6 +135,88 @@ class TestNoamLearningRateDecayDygraphMode(unittest.TestCase):
                msg='Failed lr scheduler in step {0}, Python result is {1}, Fluid result is {2}'.
                format(step, right_result, fluid_result[0]))
def test_LinearLrWarmup(self):
with fluid.dygraph.guard():
lr = fluid.layers.polynomial_decay(
learning_rate=1.0,
decay_steps=10,
end_learning_rate=0.0,
power=1.0)
lr = fluid.layers.linear_lr_warmup(
learning_rate=lr, warmup_steps=2, start_lr=0.0, end_lr=1.0)
right_result = [0.5, 0.9, 0.8, 0.7, 0.6]
for i in range(5):
t = lr()
self.assertTrue(
np.allclose((t.numpy())[0].item(), right_result[i]))
with self.assertRaises(TypeError):
lr = fluid.layers.linear_lr_warmup(
learning_rate="fake_lr",
warmup_steps=2,
start_lr=0.0,
end_lr=1.0)
def test_MultiStepDecay(self):
with fluid.dygraph.guard():
learning_rate = 0.5
milestones = [2, 4, 8]
decay_rate = 0.2
scheduler = fluid.dygraph.MultiStepDecay(learning_rate, milestones,
decay_rate)
for epoch in range(10):
right_result = multi_step_decay(epoch, learning_rate,
milestones, decay_rate)
fluid_result = scheduler().numpy()[0]
scheduler.epoch()
self.assertAlmostEqual(
right_result,
fluid_result,
msg='Failed lr scheduler in step {0}, Python result is {1}, Fluid result is {2}'.
format(epoch, right_result, fluid_result))
with self.assertRaises(ValueError):
lr = fluid.dygraph.MultiStepDecay(learning_rate, [30, 50, 20],
0.1)
with self.assertRaises(ValueError):
lr = fluid.dygraph.MultiStepDecay(learning_rate, [20, 30, 50],
1)
def test_StepDecay(self):
with fluid.dygraph.guard():
learning_rate = 0.5
step_size = 3
decay_rate = 0.2
scheduler = fluid.dygraph.StepDecay(learning_rate, step_size,
decay_rate)
for epoch in range(10):
right_result = step_decay(epoch, learning_rate, step_size,
decay_rate)
fluid_result = scheduler().numpy()[0]
scheduler.epoch()
self.assertAlmostEqual(
right_result,
fluid_result,
msg='Failed lr scheduler in step {0}, Python result is {1}, Fluid result is {2}'.
format(epoch, right_result, fluid_result))
with self.assertRaises(TypeError):
lr = fluid.dygraph.MultiStepDecay(learning_rate, "test", 0.1)
with self.assertRaises(ValueError):
lr = fluid.dygraph.MultiStepDecay(learning_rate, [20, 30, 50],
1)
with self.assertRaises(TypeError):
lr = fluid.dygraph.MultiStepDecay("test", [20, 30, 50])
with self.assertRaises(ValueError):
lr = fluid.dygraph.MultiStepDecay(2.0, [20, 30, 50])
class TestLearningRateDecay(unittest.TestCase):
    def check_decay(self, python_decay_fn, fluid_decay_fn, kwargs):
@@ -171,31 +271,26 @@ class TestLearningRateDecay(unittest.TestCase):
            (natural_exp_decay, layers.natural_exp_decay, common_kwargs_false),
            (inverse_time_decay, layers.inverse_time_decay, common_kwargs_true),
            (inverse_time_decay, layers.inverse_time_decay,
-             common_kwargs_false),
-            (polynomial_decay, layers.polynomial_decay, {
-                "learning_rate": 1.0,
-                "decay_steps": 5,
-                "cycle": True
-            }),
-            (polynomial_decay, layers.polynomial_decay, {
-                "learning_rate": 1.0,
-                "decay_steps": 5,
-                "cycle": False
-            }),
-            (piecewise_decay, layers.piecewise_decay, {
-                "boundaries": [3, 6, 9],
-                "values": [0.1, 0.2, 0.3, 0.4]
-            }),
-            (cosine_decay, layers.cosine_decay, {
-                "learning_rate": 0.1,
-                "step_each_epoch": 100,
-                "epochs": 120
-            }),
-            (noam_decay, layers.noam_decay, {
-                "d_model": 0.01,
-                "warmup_steps": 200,
-                "learning_rate": 2.0
-            }),
+             common_kwargs_false), (polynomial_decay, layers.polynomial_decay, {
+                 "learning_rate": 1.0,
+                 "decay_steps": 5,
+                 "cycle": True
+             }), (polynomial_decay, layers.polynomial_decay, {
+                 "learning_rate": 1.0,
+                 "decay_steps": 5,
+                 "cycle": False
+             }), (piecewise_decay, layers.piecewise_decay, {
+                 "boundaries": [3, 6, 9],
+                 "values": [0.1, 0.2, 0.3, 0.4]
+             }), (cosine_decay, layers.cosine_decay, {
+                 "learning_rate": 0.1,
+                 "step_each_epoch": 100,
+                 "epochs": 120
+             }), (noam_decay, layers.noam_decay, {
+                 "d_model": 0.01,
+                 "warmup_steps": 200,
+                 "learning_rate": 2.0
+             })
        ]
        for py_decay_fn, fluid_decay_fn, kwargs in decay_fns:
@@ -207,13 +302,7 @@ class TestLearningRateDecay(unittest.TestCase):
            self.check_decay(py_decay_fn, fluid_decay_fn, kwargs)
-def linear_lr_warmup(global_step, warmup_steps, start_lr, end_lr):
-    linear_step = end_lr - start_lr
-    decayed_lr = start_lr + linear_step * (global_step / warmup_steps)
-    return decayed_lr
-class TestLinearWamrupLearningRateDecay(TestLearningRateDecay):
+class TestLinearWamrupLearningRateDecay(unittest.TestCase):
    def check_decay_with_place(self, place, python_decay_fn, fluid_decay_fn,
                               kwargs):
        main_prog = fluid.Program()
@@ -304,37 +393,6 @@ class TestLinearWamrupLearningRateDecayWithScalarInput(unittest.TestCase):
        run_places(lr, start_lr, end_lr)
class TestLinearWamrupLearningRateDecayDygraphMode(unittest.TestCase):
def test_dygraph_mode(self):
with fluid.dygraph.guard():
lr = fluid.layers.polynomial_decay(
learning_rate=1.0,
decay_steps=10,
end_learning_rate=0.0,
power=1.0)
lr = fluid.layers.linear_lr_warmup(
learning_rate=lr, warmup_steps=2, start_lr=0.0, end_lr=1.0)
right_result = [0.5, 0.9, 0.8, 0.7, 0.6]
for i in range(5):
t = lr()
self.assertTrue(
np.allclose((t.numpy())[0].item(), right_result[i]))
class TestLinearWamrupLearningRateDecayDygraphModeTypeCheck(unittest.TestCase):
def test_dygraph_mode(self):
with fluid.dygraph.guard():
with self.assertRaises(TypeError):
lr = fluid.layers.linear_lr_warmup(
learning_rate="fake_lr",
warmup_steps=2,
start_lr=0.0,
end_lr=1.0)
def reduce_lr_on_plateau(decay_rate, threshold, cooldown, patience, m, n, loss,
                         var_list):
    def is_better(current, best, m, n):
...
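The closed-form reference schedules the tests compare against can be sanity-checked in plain Python. The sketch below (no Paddle required) reuses the step_decay and multi_step_decay helpers defined above, with the same parameters as test_StepDecay and test_MultiStepDecay:

# Worked example of the reference schedules compared against in the tests above.
import math


def step_decay(global_step, learning_rate, step_size, decay_rate=0.1):
    return learning_rate * math.pow(decay_rate, global_step // step_size)


def multi_step_decay(global_step, learning_rate, milestones, decay_rate=0.1):
    for i in range(len(milestones)):
        if global_step < milestones[i]:
            return learning_rate * math.pow(decay_rate, i)
    return learning_rate * math.pow(decay_rate, len(milestones))


for epoch in range(10):
    print(epoch,
          step_decay(epoch, 0.5, step_size=3, decay_rate=0.2),
          multi_step_decay(epoch, 0.5, milestones=[2, 4, 8], decay_rate=0.2))
# step_decay:       0.5 for epochs 0-2, 0.1 for 3-5, 0.02 for 6-8, 0.004 at epoch 9
# multi_step_decay: 0.5 for epochs 0-1, 0.1 for 2-3, 0.02 for 4-7, 0.004 for 8-9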