From 93ad33945047213040ac2ada102b69365430b8f2 Mon Sep 17 00:00:00 2001
From: Zhou Wei <52485244+zhouwei25@users.noreply.github.com>
Date: Thu, 9 Jul 2020 21:06:53 +0800
Subject: [PATCH] add new API:LambdaDecay,test=develop (#24801)

add new API:LambdaDecay
---
 .../fluid/dygraph/learning_rate_scheduler.py  | 69 ++++++++++++++++++-
 .../unittests/test_learning_rate_scheduler.py | 27 ++++++++
 2 files changed, 95 insertions(+), 1 deletion(-)

diff --git a/python/paddle/fluid/dygraph/learning_rate_scheduler.py b/python/paddle/fluid/dygraph/learning_rate_scheduler.py
index 1bfd4c77a26..a1adef52c04 100644
--- a/python/paddle/fluid/dygraph/learning_rate_scheduler.py
+++ b/python/paddle/fluid/dygraph/learning_rate_scheduler.py
@@ -23,7 +23,7 @@ from ..data_feeder import check_type
 __all__ = [
     'NoamDecay', 'PiecewiseDecay', 'NaturalExpDecay', 'ExponentialDecay',
     'InverseTimeDecay', 'PolynomialDecay', 'CosineDecay', 'LinearLrWarmup',
-    'ReduceLROnPlateau', 'StepDecay', 'MultiStepDecay'
+    'ReduceLROnPlateau', 'StepDecay', 'MultiStepDecay', 'LambdaDecay'
 ]
 
 
@@ -1087,3 +1087,70 @@ class MultiStepDecay(_LearningRateEpochDecay):
                 return self.base_lr * (decay_rate**i)
 
         return self.base_lr * (decay_rate**len(self.milestones))
+
+
+class LambdaDecay(_LearningRateEpochDecay):
+    """
+    :api_attr: imperative
+
+    Sets the learning rate of ``optimizer`` to the initial lr times a multiplicative factor, and this multiplicative
+    factor is computed by the function ``lr_lambda`` . ``lr_lambda`` is a function which receives ``epoch`` .
+
+    The algorithm can be described as the code below.
+
+    .. code-block:: text
+
+        learning_rate = 0.5        # init learning_rate
+        lr_lambda = lambda epoch: 0.95 ** epoch
+
+        learning_rate = 0.5        # epoch 0
+        learning_rate = 0.475      # epoch 1
+        learning_rate = 0.45125    # epoch 2
+
+    Parameters:
+        learning_rate (float|int): The initial learning rate. It can be set to python float or int number.
+        lr_lambda (function): A function which computes a multiplicative factor given an integer parameter ``epoch`` ,
+            and the initial learning rate will then be multiplied by this factor.
+
+    Returns:
+        None.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle.fluid as fluid
+            import numpy as np
+            with fluid.dygraph.guard():
+                x = np.random.uniform(-1, 1, [10, 10]).astype("float32")
+                linear = fluid.dygraph.Linear(10, 10)
+                input = fluid.dygraph.to_variable(x)
+                scheduler = fluid.dygraph.LambdaDecay(0.5, lr_lambda=lambda x: 0.95**x)
+                adam = fluid.optimizer.Adam(learning_rate = scheduler, parameter_list = linear.parameters())
+
+                for epoch in range(6):
+                    for batch_id in range(5):
+                        out = linear(input)
+                        loss = fluid.layers.reduce_mean(out)
+                        adam.minimize(loss)
+                    scheduler.epoch()
+
+                    print("epoch:%d, current lr is %f" % (epoch, adam.current_step_lr()))
+                    # epoch:0, current lr is 0.5
+                    # epoch:1, current lr is 0.475
+                    # epoch:2, current lr is 0.45125
+
+    """
+
+    def __init__(self, learning_rate, lr_lambda):
+        if not callable(lr_lambda):
+            raise TypeError(
+                "The type of 'lr_lambda' in 'LambdaDecay' must be 'function', but received %s."
+                % type(lr_lambda))
+
+        self.lr_lambda = lr_lambda
+        super(LambdaDecay, self).__init__(learning_rate)
+
+    def get_lr(self):
+        base_lr = self.create_lr_var(self.base_lr)
+
+        return self.base_lr * self.lr_lambda(self.epoch_num)
diff --git a/python/paddle/fluid/tests/unittests/test_learning_rate_scheduler.py b/python/paddle/fluid/tests/unittests/test_learning_rate_scheduler.py
index 3b19b7bb10e..73e1f148fbb 100644
--- a/python/paddle/fluid/tests/unittests/test_learning_rate_scheduler.py
+++ b/python/paddle/fluid/tests/unittests/test_learning_rate_scheduler.py
@@ -116,6 +116,10 @@ def step_decay(global_step, learning_rate, step_size, decay_rate=0.1):
     return learning_rate * math.pow(decay_rate, global_step // step_size)
 
 
+def lambda_decay(global_step, learning_rate, lr_lambda):
+    return learning_rate * lr_lambda(global_step)
+
+
 class TestLearningRateDecayDygraph(unittest.TestCase):
     def test_NoamDecay(self):
         with fluid.dygraph.guard():
@@ -217,6 +221,29 @@ class TestLearningRateDecayDygraph(unittest.TestCase):
             lr = fluid.dygraph.MultiStepDecay(learning_rate, [20, 30, 50], 1)
+
+    def test_LambdaDecay(self):
+        with fluid.dygraph.guard():
+            learning_rate = 0.5
+            lr_lambda = lambda x: 0.95**x
+            scheduler = fluid.dygraph.LambdaDecay(learning_rate, lr_lambda)
+
+            linear = fluid.dygraph.nn.Linear(10, 10)
+            adam = fluid.optimizer.Adam(
+                scheduler, parameter_list=linear.parameters())
+
+            for epoch in range(30):
+                right_result = lambda_decay(epoch, learning_rate, lr_lambda)
+                fluid_result = scheduler().numpy()[0]
+                scheduler.epoch()
+                self.assertAlmostEqual(
+                    right_result,
+                    fluid_result,
+                    msg='Failed lr scheduler in epoch {0}, Python result is {1}, Fluid result is {2}'.
+                    format(epoch, right_result, fluid_result))
+
+            with self.assertRaises(TypeError):
+                lr = fluid.dygraph.LambdaDecay(learning_rate, "test")
+
 
 class TestLearningRateDecay(unittest.TestCase):
     def check_decay(self, python_decay_fn, fluid_decay_fn, kwargs):
--
GitLab
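
Note on the decay rule (not part of the patch): ``LambdaDecay`` simply recomputes
lr = learning_rate * lr_lambda(epoch) from the initial learning rate each time
``scheduler.epoch()`` advances the epoch counter. A minimal pure-Python sketch of
that rule follows; the ``lambda_decay`` helper mirrors the reference function added
to the unit test above, while the small driver loop is illustrative only.

.. code-block:: python

    def lambda_decay(global_step, learning_rate, lr_lambda):
        # Multiplicative decay: scale the initial lr by lr_lambda(global_step).
        return learning_rate * lr_lambda(global_step)

    base_lr = 0.5
    lr_lambda = lambda epoch: 0.95**epoch  # same schedule as the docstring example

    for epoch in range(3):
        print("epoch:%d, current lr is %f" % (epoch, lambda_decay(epoch, base_lr, lr_lambda)))
    # epoch:0, current lr is 0.500000
    # epoch:1, current lr is 0.475000
    # epoch:2, current lr is 0.451250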