diff --git a/python/paddle/fluid/dygraph/learning_rate_scheduler.py b/python/paddle/fluid/dygraph/learning_rate_scheduler.py
index f54654f8adf00ed8ca0bae508ed7905b3359c737..864a00e2e14fe5342c806c4271c070fb45141266 100644
--- a/python/paddle/fluid/dygraph/learning_rate_scheduler.py
+++ b/python/paddle/fluid/dygraph/learning_rate_scheduler.py
@@ -553,3 +553,92 @@ class NoamDecay(LearningRateDecay):
         b = self.create_lr_var((self.warmup_steps**-1.5) * self.step_num)
         lr_value = (self.d_model**-0.5) * layers.elementwise_min(a, b)
         return lr_value
+
+
+class LinearLrWarmup(LearningRateDecay):
+    """
+    This operator uses the linear learning rate warm-up strategy to adjust the learning rate preliminarily before the normal learning rate scheduling.
+    For more information, please refer to `Bag of Tricks for Image Classification with Convolutional Neural Networks <https://arxiv.org/abs/1812.01187>`_
+
+    When global_step < warmup_steps, the learning rate is updated as:
+
+    .. code-block:: text
+
+        linear_step = end_lr - start_lr
+        lr = start_lr + linear_step * (global_step / warmup_steps)
+
+    where start_lr is the initial learning rate and end_lr is the final learning rate of the warm-up;
+
+    When global_step >= warmup_steps, the learning rate is updated as:
+
+    .. code-block:: text
+
+        lr = learning_rate
+
+    where lr is the learning_rate after warm-up.
+
+    Args:
+        learning_rate (Variable|float): The learning rate used after warm-up. It can be a 1-D Tensor or a scalar value with the data type float32.
+        warmup_steps (int): Number of steps for warm up.
+        start_lr (float): Initial learning rate of warm up.
+        end_lr (float): Final learning rate of warm up.
+        begin (int, optional): The begin step, i.e. the initial value of global_step described above. The default value is 1.
+        step (int, optional): The step size used to calculate the new global_step in the description above.
+            The default value is 1.
+        dtype (str, optional): The data type used to create the learning rate variable. The data type can be
+            'float32' or 'float64'. The default value is 'float32'.
+
+    Returns:
+        Variable: Warm-up learning rate with the same data type as learning_rate.
+
+    Examples:
+
+    .. code-block:: python
+
+        import paddle.fluid as fluid
+
+        learning_rate = 0.1
+        warmup_steps = 50
+        start_lr = 1. / 3.
+        end_lr = 0.1
+
+        with fluid.dygraph.guard():
+            lr_decay = fluid.dygraph.LinearLrWarmup(learning_rate, warmup_steps, start_lr, end_lr)
+
+    """
+
+    def __init__(self,
+                 learning_rate,
+                 warmup_steps,
+                 start_lr,
+                 end_lr,
+                 begin=1,
+                 step=1,
+                 dtype='float32'):
+        super(LinearLrWarmup, self).__init__(begin, step, dtype)
+        if not isinstance(learning_rate, (int, float, LearningRateDecay)):
+            raise TypeError(
+                "the type of learning_rate should be [int, float or "
+                "LearningRateDecay], the current type is {}".format(
+                    learning_rate))
+        self.learning_rate = learning_rate
+        self.warmup_steps = warmup_steps
+        assert end_lr > start_lr, \
+            "end_lr {} must be greater than start_lr {}".format(end_lr,
+                                                                start_lr)
+        self.start_lr = float(start_lr)
+        self.lr_ratio_before_warmup = (
+            float(end_lr) - float(start_lr)) / float(warmup_steps)
+
+    def step(self):
+        base_lr = self.learning_rate
+        if isinstance(self.learning_rate, LearningRateDecay):
+            base_lr = base_lr()
+
+        from .. import layers
+        if self.step_num < self.warmup_steps:
+            return self.start_lr + self.lr_ratio_before_warmup * self.step_num
+        else:
+            return base_lr
diff --git a/python/paddle/fluid/layers/learning_rate_scheduler.py b/python/paddle/fluid/layers/learning_rate_scheduler.py
index 2e397412767c244fa299205f25ee53d1cfb8361c..65d837baa1a8b32e8a24b04de677bcc6505c3610 100644
--- a/python/paddle/fluid/layers/learning_rate_scheduler.py
+++ b/python/paddle/fluid/layers/learning_rate_scheduler.py
@@ -519,23 +519,29 @@ def linear_lr_warmup(learning_rate, warmup_steps, start_lr, end_lr):
     linear_step = float(end_lr) - float(start_lr)
 
     with default_main_program()._lr_schedule_guard():
-        lr = tensor.create_global_var(
-            shape=[1],
-            value=0.0,
-            dtype=dtype,
-            persistable=True,
-            name="learning_rate_warmup")
-
-        global_step = _decay_step_counter()
-
-        with control_flow.Switch() as switch:
-            with switch.case(global_step < warmup_steps):
-                decayed_lr = start_lr + linear_step * (global_step /
-                                                       float(warmup_steps))
-                tensor.assign(decayed_lr, lr)
-            with switch.default():
-                if not isinstance(learning_rate, Variable):
-                    learning_rate = tensor.fill_constant(
-                        shape=[1], dtype=dtype, value=float(learning_rate))
-                tensor.assign(learning_rate, lr)
-    return lr
+
+        if imperative_base.enabled():
+            lr = imperate_lr.LinearLrWarmup(learning_rate, warmup_steps,
+                                            start_lr, end_lr)
+            return lr
+        else:
+            lr = tensor.create_global_var(
+                shape=[1],
+                value=0.0,
+                dtype=dtype,
+                persistable=True,
+                name="learning_rate_warmup")
+
+            global_step = _decay_step_counter()
+
+            with control_flow.Switch() as switch:
+                with switch.case(global_step < warmup_steps):
+                    decayed_lr = start_lr + linear_step * (global_step /
+                                                           float(warmup_steps))
+                    tensor.assign(decayed_lr, lr)
+                with switch.default():
+                    if not isinstance(learning_rate, Variable):
+                        learning_rate = tensor.fill_constant(
+                            shape=[1], dtype=dtype, value=float(learning_rate))
+                    tensor.assign(learning_rate, lr)
+            return lr
diff --git a/python/paddle/fluid/tests/unittests/test_learning_rate_scheduler.py b/python/paddle/fluid/tests/unittests/test_learning_rate_scheduler.py
index 88d9919f59619cabad2e4ceca839e4a13d2cfd23..e3f79448e7394f1148416a70b08c2bdb128905ce 100644
--- a/python/paddle/fluid/tests/unittests/test_learning_rate_scheduler.py
+++ b/python/paddle/fluid/tests/unittests/test_learning_rate_scheduler.py
@@ -264,5 +264,35 @@ class TestLinearWamrupLearningRateDecayWithScalarInput(unittest.TestCase):
         run_places(lr, start_lr, end_lr)
 
 
+class TestLinearWamrupLearningRateDecayDygraphMode(unittest.TestCase):
+    def test_dygraph_mode(self):
+        with fluid.dygraph.guard():
+            lr = fluid.layers.polynomial_decay(
+                learning_rate=1.0,
+                decay_steps=10,
+                end_learning_rate=0.0,
+                power=1.0)
+            lr = fluid.layers.linear_lr_warmup(
+                learning_rate=lr, warmup_steps=2, start_lr=0.0, end_lr=1.0)
+
+            right_result = [0.5, 0.9, 0.8, 0.7, 0.6]
+            for i in range(5):
+                t = lr().numpy()[0]
+                self.assertAlmostEqual(t, right_result[i], places=5)
+
+
+class TestLinearWamrupLearningRateDecayDygraphModeTypeCheck(unittest.TestCase):
+    def test_dygraph_mode(self):
+        with fluid.dygraph.guard():
+            with self.assertRaises(TypeError):
+                lr = fluid.layers.linear_lr_warmup(
+                    learning_rate="fake_lr",
+                    warmup_steps=2,
+                    start_lr=0.0,
+                    end_lr=1.0)
+
+
 if __name__ == '__main__':
     unittest.main()
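As a quick sanity check on the schedule introduced above, the short standalone sketch below (plain Python, independent of Paddle; the helper names polynomial_decay_lr and linear_warmup_lr are illustrative only, not Paddle APIs) reproduces the values asserted in TestLinearWamrupLearningRateDecayDygraphMode, assuming the warm-up formula from the LinearLrWarmup docstring, a power-1 polynomial decay over 10 steps, and the dygraph step counting (warm-up steps start at begin=1, the wrapped decay starts at 0).

# Standalone sketch: reproduces the expected dygraph schedule
# [0.5, 0.9, 0.8, 0.7, 0.6] from the unit test above. Helper names
# are illustrative only and are not part of the Paddle API.


def polynomial_decay_lr(step, learning_rate=1.0, decay_steps=10,
                        end_learning_rate=0.0, power=1.0):
    # Power-1 polynomial decay: lr falls linearly from learning_rate
    # to end_learning_rate over decay_steps.
    ratio = (1.0 - float(step) / decay_steps) ** power
    return (learning_rate - end_learning_rate) * ratio + end_learning_rate


def linear_warmup_lr(step, base_lr, warmup_steps=2, start_lr=0.0, end_lr=1.0):
    # Docstring formula: lr = start_lr + (end_lr - start_lr) * step / warmup_steps
    if step < warmup_steps:
        return start_lr + (end_lr - start_lr) * float(step) / warmup_steps
    # Once warm-up is over, fall back to the wrapped schedule's value.
    return base_lr


if __name__ == "__main__":
    # LinearLrWarmup counts from begin=1, while the wrapped PolynomialDecay
    # counts from 0, so at warm-up step i the decay value is for step i - 1.
    values = [linear_warmup_lr(i, polynomial_decay_lr(i - 1))
              for i in range(1, 6)]
    print(values)  # [0.5, 0.9, 0.8, 0.7, 0.6]

This also shows why the test's first expected value is 0.5 rather than start_lr: with begin=1, the first recorded value is already one step into the two-step warm-up.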