From d171f373a333c46299126d45ede4680d8710160b Mon Sep 17 00:00:00 2001 From: Zhou Wei <52485244+zhouwei25@users.noreply.github.com> Date: Wed, 1 Jul 2020 23:15:14 +0800 Subject: [PATCH] [CHERR-PICK1.8]add base class LearningRateEpochDecay, and MultiStepDecay, StepDecay (#25277) * CHERR-PICK1.8,add base class of LearningRateEpochDecay, and API: MultiStepDecay, and API: StepDecay,test=release/1.8 * fix unittest to add coverage,test=develop --- .../fluid/dygraph/learning_rate_scheduler.py | 236 +++++++++++++++++- .../fluid/layers/learning_rate_scheduler.py | 1 - .../unittests/test_learning_rate_scheduler.py | 188 +++++++++----- 3 files changed, 356 insertions(+), 69 deletions(-) diff --git a/python/paddle/fluid/dygraph/learning_rate_scheduler.py b/python/paddle/fluid/dygraph/learning_rate_scheduler.py index 7e4e50bc476..687767d939f 100644 --- a/python/paddle/fluid/dygraph/learning_rate_scheduler.py +++ b/python/paddle/fluid/dygraph/learning_rate_scheduler.py @@ -23,7 +23,7 @@ from ..data_feeder import check_type __all__ = [ 'NoamDecay', 'PiecewiseDecay', 'NaturalExpDecay', 'ExponentialDecay', 'InverseTimeDecay', 'PolynomialDecay', 'CosineDecay', 'LinearLrWarmup', - 'ReduceLROnPlateau' + 'ReduceLROnPlateau', 'StepDecay', 'MultiStepDecay' ] @@ -72,6 +72,8 @@ class LearningRateDecay(object): class PiecewiseDecay(LearningRateDecay): """ + :api_attr: imperative + Piecewise decay scheduler. The algorithm can be described as the code below. @@ -131,6 +133,8 @@ class PiecewiseDecay(LearningRateDecay): class NaturalExpDecay(LearningRateDecay): """ + :api_attr: imperative + Applies natural exponential decay to the initial learning rate. The algorithm can be described as following. @@ -183,7 +187,6 @@ class NaturalExpDecay(LearningRateDecay): staircase=True), parameter_list=emb.parameters()) - """ def __init__(self, @@ -213,6 +216,8 @@ class NaturalExpDecay(LearningRateDecay): class ExponentialDecay(LearningRateDecay): """ + :api_attr: imperative + Applies exponential decay to the learning rate. The algorithm can be described as following. @@ -293,6 +298,8 @@ class ExponentialDecay(LearningRateDecay): class InverseTimeDecay(LearningRateDecay): """ + :api_attr: imperative + Applies inverse time decay to the initial learning rate. The algorithm can be described as following. @@ -369,6 +376,8 @@ class InverseTimeDecay(LearningRateDecay): class PolynomialDecay(LearningRateDecay): """ + :api_attr: imperative + Applies polynomial decay to the initial learning rate. The algorithm can be described as following. @@ -461,6 +470,8 @@ class PolynomialDecay(LearningRateDecay): class CosineDecay(LearningRateDecay): """ + :api_attr: imperative + Applies cosine decay to the learning rate. The algorithm can be described as following. @@ -517,6 +528,8 @@ class CosineDecay(LearningRateDecay): class NoamDecay(LearningRateDecay): """ + :api_attr: imperative + Applies Noam decay to the initial learning rate. The algorithm can be described as following. @@ -582,6 +595,8 @@ class NoamDecay(LearningRateDecay): class LinearLrWarmup(LearningRateDecay): """ + :api_attr: imperative + This operator use the linear learning rate warm up strategy to adjust the learning rate preliminarily before the normal learning rate scheduling. For more information, please refer to `Bag of Tricks for Image Classification with Convolutional Neural Networks `_ @@ -670,6 +685,8 @@ class LinearLrWarmup(LearningRateDecay): class ReduceLROnPlateau(LearningRateDecay): """ + :api_attr: imperative + Reduce learning rate when ``loss`` has stopped descending. Models often benefit from reducing the learning rate by 2 to 10 times once model performance has no longer improvement. @@ -774,7 +791,6 @@ class ReduceLROnPlateau(LearningRateDecay): raise ValueError('threshold mode ' + threshold_mode + ' is unknown!') self.threshold_mode = threshold_mode - check_type(learning_rate, 'learning_rate', (float, int, Variable), 'ReduceLROnPlateau') if isinstance(learning_rate, (float, int)): @@ -856,3 +872,217 @@ class ReduceLROnPlateau(LearningRateDecay): else: return current > best + self.threshold + + +class _LearningRateEpochDecay(LearningRateDecay): + """ + :api_attr: imperative + + Base class of learning rate decay, which is updated each epoch. + + Define the common interface of an _LearningRateEpochDecay. + User should not use this class directly, + but need to use one of it's implementation. And invoke method: `epoch()` each epoch. + """ + + def __init__(self, learning_rate, dtype=None): + if not isinstance(learning_rate, (float, int)): + raise TypeError( + "The type of 'learning_rate' must be 'float, int', but received %s." + % type(learning_rate)) + if learning_rate >= 1.0: + raise ValueError("The initial learning rate") + + self.base_lr = float(learning_rate) + + self.epoch_num = -1 + if dtype is None: + self.dtype = "float32" + self.learning_rate = self.create_lr_var(self.base_lr) + + self.epoch() + + def __call__(self): + """ + Return last computed learning rate on current epoch. + """ + return self.learning_rate + + def epoch(self, epoch=None): + """ + compueted learning_rate and update it when invoked. + """ + if epoch is None: + self.epoch_num += 1 + else: + self.epoch_num = epoch + + self.learning_rate = self.get_lr() + if isinstance(self.learning_rate, float): + self.learning_rate = self.create_lr_var(self.learning_rate) + + def get_lr(self): + raise NotImplementedError + + +class StepDecay(_LearningRateEpochDecay): + """ + :api_attr: imperative + + Decays the learning rate of ``optimizer`` by ``decay_rate`` every ``step_size`` number of epoch. + + The algorithm can be described as the code below. + + .. code-block:: text + + learning_rate = 0.5 + step_size = 30 + decay_rate = 0.1 + + learning_rate = 0.5 if epoch < 30 + learning_rate = 0.05 if 30 <= epoch < 60 + learning_rate = 0.005 if 60 <= epoch < 90 + ... + + Parameters: + learning_rate (float|int): The initial learning rate. It can be set to python float or int number. + step_size (int): Period of learning rate decay.. + decay_rate (float, optional): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * decay_rate`` . + It should be less than 1.0. Default: 0.1. + + Returns: + None. + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + import numpy as np + with fluid.dygraph.guard(): + x = np.random.uniform(-1, 1, [10, 10]).astype("float32") + linear = fluid.dygraph.Linear(10, 10) + input = fluid.dygraph.to_variable(x) + scheduler = fluid.dygraph.StepDecay(0.5, step_size=3) + adam = fluid.optimizer.Adam(learning_rate = scheduler, parameter_list = linear.parameters()) + + for epoch in range(9): + for batch_id in range(5): + out = linear(input) + loss = fluid.layers.reduce_mean(out) + adam.minimize(loss) + scheduler.epoch() + + print("epoch:{}, current lr is {}" .format(epoch, adam.current_step_lr())) + # epoch:0, current lr is 0.5 + # epoch:1, current lr is 0.5 + # epoch:2, current lr is 0.5 + # epoch:3, current lr is 0.05 + # epoch:4, current lr is 0.05 + # epoch:5, current lr is 0.05 + # epoch:6, current lr is 0.005 + # epoch:7, current lr is 0.005 + # epoch:8, current lr is 0.005 + + """ + + def __init__(self, learning_rate, step_size, decay_rate=0.1): + if not isinstance(step_size, int): + raise TypeError( + "The type of 'step_size' must be 'int', but received %s." % + type(step_size)) + if decay_rate >= 1.0: + raise ValueError('decay_rate should be < 1.0.') + + self.step_size = step_size + self.decay_rate = decay_rate + super(StepDecay, self).__init__(learning_rate) + + def get_lr(self): + decay_rate = self.create_lr_var(self.decay_rate) + i = self.epoch_num // self.step_size + return self.base_lr * (decay_rate**i) + + +class MultiStepDecay(_LearningRateEpochDecay): + """ + :api_attr: imperative + + Decays the learning rate of ``optimizer`` by ``decay_rate`` once ``epoch`` reaches one of the milestones. + + The algorithm can be described as the code below. + + .. code-block:: text + + learning_rate = 0.5 + milestones = [30, 50] + decay_rate = 0.1 + if epoch < 30: + learning_rate = 0.5 + elif epoch < 50: + learning_rate = 0.05 + else: + learning_rate = 0.005 + + Parameters: + learning_rate (float|int): The initial learning rate. It can be set to python float or int number. If it + milestones (tuple|list): List or tuple of each boundaries. Must be increasing. + decay_rate (float, optional): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * decay_rate`` . + It should be less than 1.0. Default: 0.1. + + Returns: + None. + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + import numpy as np + with fluid.dygraph.guard(): + x = np.random.uniform(-1, 1, [10, 10]).astype("float32") + linear = fluid.dygraph.Linear(10, 10) + input = fluid.dygraph.to_variable(x) + scheduler = fluid.dygraph.MultiStepDecay(0.5, milestones=[3, 5]) + adam = fluid.optimizer.Adam(learning_rate = scheduler, parameter_list = linear.parameters()) + + for epoch in range(6): + for batch_id in range(5): + out = linear(input) + loss = fluid.layers.reduce_mean(out) + adam.minimize(loss) + scheduler.epoch() + + print("epoch:{}, current lr is {}" .format(epoch, adam.current_step_lr())) + # epoch:0, current lr is 0.5 + # epoch:1, current lr is 0.5 + # epoch:2, current lr is 0.5 + # epoch:3, current lr is 0.05 + # epoch:4, current lr is 0.05 + # epoch:5, current lr is 0.005 + + """ + + def __init__(self, learning_rate, milestones, decay_rate=0.1): + if not isinstance(milestones, (tuple, list)): + raise TypeError( + "The type of 'milestones' in 'MultiStepDecay' must be 'tuple, list', but received %s." + % type(milestones)) + + if not all([ + milestones[i] < milestones[i + 1] + for i in range(len(milestones) - 1) + ]): + raise ValueError('The elements of milestones must be incremented') + if decay_rate >= 1.0: + raise ValueError('decay_rate should be < 1.0.') + + self.milestones = milestones + self.decay_rate = decay_rate + super(MultiStepDecay, self).__init__(learning_rate) + + def get_lr(self): + decay_rate = self.create_lr_var(self.decay_rate) + for i in range(len(self.milestones)): + if self.epoch_num < self.milestones[i]: + return self.base_lr * (decay_rate**i) + + return self.base_lr * (decay_rate**len(self.milestones)) diff --git a/python/paddle/fluid/layers/learning_rate_scheduler.py b/python/paddle/fluid/layers/learning_rate_scheduler.py index 0115b398e66..ec062b233f4 100644 --- a/python/paddle/fluid/layers/learning_rate_scheduler.py +++ b/python/paddle/fluid/layers/learning_rate_scheduler.py @@ -498,7 +498,6 @@ def linear_lr_warmup(learning_rate, warmup_steps, start_lr, end_lr): Returns: Variable: Warm-up learning rate with the same data type as learning_rate. - Examples: .. code-block:: python diff --git a/python/paddle/fluid/tests/unittests/test_learning_rate_scheduler.py b/python/paddle/fluid/tests/unittests/test_learning_rate_scheduler.py index 8b66035c57a..1bf8acae553 100644 --- a/python/paddle/fluid/tests/unittests/test_learning_rate_scheduler.py +++ b/python/paddle/fluid/tests/unittests/test_learning_rate_scheduler.py @@ -98,8 +98,26 @@ def noam_decay(global_step, d_model, warmup_steps, learning_rate=1.0): return decayed_lr -class TestNoamLearningRateDecayDygraphMode(unittest.TestCase): - def test_dygraph_mode(self): +def linear_lr_warmup(global_step, warmup_steps, start_lr, end_lr): + linear_step = end_lr - start_lr + decayed_lr = start_lr + linear_step * (global_step / warmup_steps) + return decayed_lr + + +def multi_step_decay(global_step, learning_rate, milestones, decay_rate=0.1): + for i in range(len(milestones)): + if global_step < milestones[i]: + return learning_rate * math.pow(decay_rate, i) + + return learning_rate * math.pow(decay_rate, len(milestones)) + + +def step_decay(global_step, learning_rate, step_size, decay_rate=0.1): + return learning_rate * math.pow(decay_rate, global_step // step_size) + + +class TestLearningRateDecayDygraph(unittest.TestCase): + def test_NoamDecay(self): with fluid.dygraph.guard(): d_model = 0.01 warmup_steps = 200 @@ -117,6 +135,88 @@ class TestNoamLearningRateDecayDygraphMode(unittest.TestCase): msg='Failed lr scheduler in step {0}, Python result is {1}, Fluid result is {2}'. format(step, right_result, fluid_result[0])) + def test_LinearLrWarmup(self): + with fluid.dygraph.guard(): + lr = fluid.layers.polynomial_decay( + learning_rate=1.0, + decay_steps=10, + end_learning_rate=0.0, + power=1.0) + lr = fluid.layers.linear_lr_warmup( + learning_rate=lr, warmup_steps=2, start_lr=0.0, end_lr=1.0) + + right_result = [0.5, 0.9, 0.8, 0.7, 0.6] + for i in range(5): + + t = lr() + + self.assertTrue( + np.allclose((t.numpy())[0].item(), right_result[i])) + + with self.assertRaises(TypeError): + lr = fluid.layers.linear_lr_warmup( + learning_rate="fake_lr", + warmup_steps=2, + start_lr=0.0, + end_lr=1.0) + + def test_MultiStepDecay(self): + with fluid.dygraph.guard(): + learning_rate = 0.5 + milestones = [2, 4, 8] + decay_rate = 0.2 + scheduler = fluid.dygraph.MultiStepDecay(learning_rate, milestones, + decay_rate) + for epoch in range(10): + right_result = multi_step_decay(epoch, learning_rate, + milestones, decay_rate) + fluid_result = scheduler().numpy()[0] + scheduler.epoch() + self.assertAlmostEqual( + right_result, + fluid_result, + msg='Failed lr scheduler in step {0}, Python result is {1}, Fluid result is {2}'. + format(epoch, right_result, fluid_result)) + + with self.assertRaises(ValueError): + lr = fluid.dygraph.MultiStepDecay(learning_rate, [30, 50, 20], + 0.1) + + with self.assertRaises(ValueError): + lr = fluid.dygraph.MultiStepDecay(learning_rate, [20, 30, 50], + 1) + + def test_StepDecay(self): + with fluid.dygraph.guard(): + learning_rate = 0.5 + step_size = 3 + decay_rate = 0.2 + scheduler = fluid.dygraph.StepDecay(learning_rate, step_size, + decay_rate) + for epoch in range(10): + right_result = step_decay(epoch, learning_rate, step_size, + decay_rate) + fluid_result = scheduler().numpy()[0] + scheduler.epoch() + self.assertAlmostEqual( + right_result, + fluid_result, + msg='Failed lr scheduler in step {0}, Python result is {1}, Fluid result is {2}'. + format(epoch, right_result, fluid_result)) + + with self.assertRaises(TypeError): + lr = fluid.dygraph.MultiStepDecay(learning_rate, "test", 0.1) + + with self.assertRaises(ValueError): + lr = fluid.dygraph.MultiStepDecay(learning_rate, [20, 30, 50], + 1) + + with self.assertRaises(TypeError): + lr = fluid.dygraph.MultiStepDecay("test", [20, 30, 50]) + + with self.assertRaises(ValueError): + lr = fluid.dygraph.MultiStepDecay(2.0, [20, 30, 50]) + class TestLearningRateDecay(unittest.TestCase): def check_decay(self, python_decay_fn, fluid_decay_fn, kwargs): @@ -171,31 +271,26 @@ class TestLearningRateDecay(unittest.TestCase): (natural_exp_decay, layers.natural_exp_decay, common_kwargs_false), (inverse_time_decay, layers.inverse_time_decay, common_kwargs_true), (inverse_time_decay, layers.inverse_time_decay, - common_kwargs_false), - (polynomial_decay, layers.polynomial_decay, { - "learning_rate": 1.0, - "decay_steps": 5, - "cycle": True - }), - (polynomial_decay, layers.polynomial_decay, { - "learning_rate": 1.0, - "decay_steps": 5, - "cycle": False - }), - (piecewise_decay, layers.piecewise_decay, { - "boundaries": [3, 6, 9], - "values": [0.1, 0.2, 0.3, 0.4] - }), - (cosine_decay, layers.cosine_decay, { - "learning_rate": 0.1, - "step_each_epoch": 100, - "epochs": 120 - }), - (noam_decay, layers.noam_decay, { - "d_model": 0.01, - "warmup_steps": 200, - "learning_rate": 2.0 - }), + common_kwargs_false), (polynomial_decay, layers.polynomial_decay, { + "learning_rate": 1.0, + "decay_steps": 5, + "cycle": True + }), (polynomial_decay, layers.polynomial_decay, { + "learning_rate": 1.0, + "decay_steps": 5, + "cycle": False + }), (piecewise_decay, layers.piecewise_decay, { + "boundaries": [3, 6, 9], + "values": [0.1, 0.2, 0.3, 0.4] + }), (cosine_decay, layers.cosine_decay, { + "learning_rate": 0.1, + "step_each_epoch": 100, + "epochs": 120 + }), (noam_decay, layers.noam_decay, { + "d_model": 0.01, + "warmup_steps": 200, + "learning_rate": 2.0 + }) ] for py_decay_fn, fluid_decay_fn, kwargs in decay_fns: @@ -207,13 +302,7 @@ class TestLearningRateDecay(unittest.TestCase): self.check_decay(py_decay_fn, fluid_decay_fn, kwargs) -def linear_lr_warmup(global_step, warmup_steps, start_lr, end_lr): - linear_step = end_lr - start_lr - decayed_lr = start_lr + linear_step * (global_step / warmup_steps) - return decayed_lr - - -class TestLinearWamrupLearningRateDecay(TestLearningRateDecay): +class TestLinearWamrupLearningRateDecay(unittest.TestCase): def check_decay_with_place(self, place, python_decay_fn, fluid_decay_fn, kwargs): main_prog = fluid.Program() @@ -304,37 +393,6 @@ class TestLinearWamrupLearningRateDecayWithScalarInput(unittest.TestCase): run_places(lr, start_lr, end_lr) -class TestLinearWamrupLearningRateDecayDygraphMode(unittest.TestCase): - def test_dygraph_mode(self): - with fluid.dygraph.guard(): - lr = fluid.layers.polynomial_decay( - learning_rate=1.0, - decay_steps=10, - end_learning_rate=0.0, - power=1.0) - lr = fluid.layers.linear_lr_warmup( - learning_rate=lr, warmup_steps=2, start_lr=0.0, end_lr=1.0) - - right_result = [0.5, 0.9, 0.8, 0.7, 0.6] - for i in range(5): - - t = lr() - - self.assertTrue( - np.allclose((t.numpy())[0].item(), right_result[i])) - - -class TestLinearWamrupLearningRateDecayDygraphModeTypeCheck(unittest.TestCase): - def test_dygraph_mode(self): - with fluid.dygraph.guard(): - with self.assertRaises(TypeError): - lr = fluid.layers.linear_lr_warmup( - learning_rate="fake_lr", - warmup_steps=2, - start_lr=0.0, - end_lr=1.0) - - def reduce_lr_on_plateau(decay_rate, threshold, cooldown, patience, m, n, loss, var_list): def is_better(current, best, m, n): -- GitLab