diff --git a/python/paddle/optimizer/optimizer.py b/python/paddle/optimizer/optimizer.py
index 5187a651b97830ab24fe04e8fa4ce452fea65510..12fcd90c67dcea1b3c7fd0d030b0b0e6a6b74e71 100644
--- a/python/paddle/optimizer/optimizer.py
+++ b/python/paddle/optimizer/optimizer.py
@@ -553,6 +553,52 @@ class Optimizer:
                     stop_gradient=True,
                 )
 
+    @framework.dygraph_only
+    def set_lr_scheduler(self, scheduler):
+        """
+        :api_attr: imperative
+
+        Set the LRScheduler of the learning rate manually in the optimizer. If the optimizer already used an LRScheduler
+        previously, this API will replace it with the new one.
+
+        Args:
+            scheduler (LRScheduler): the LRScheduler of the learning rate
+
+        Returns:
+            None
+
+        Examples:
+            .. code-block:: python
+
+                import paddle
+                linear = paddle.nn.Linear(10, 10)
+
+                adam = paddle.optimizer.Adam(0.1, parameters=linear.parameters())
+
+                # set the learning rate manually via an LRScheduler
+                scheduler = paddle.optimizer.lr.MultiStepDecay(learning_rate=0.5, milestones=[2,4,6], gamma=0.8)
+                adam.set_lr_scheduler(scheduler)
+                lr = adam.get_lr()
+                print("current lr is {}".format(lr))
+                # current lr is 0.5
+
+                # set the learning rate manually via another LRScheduler
+                scheduler = paddle.optimizer.lr.StepDecay(learning_rate=0.1, step_size=5, gamma=0.6)
+                adam.set_lr_scheduler(scheduler)
+                lr = adam.get_lr()
+                print("current lr is {}".format(lr))
+                # current lr is 0.1
+
+        """
+        from paddle.optimizer.lr import LRScheduler
+
+        if not isinstance(scheduler, LRScheduler):
+            raise TypeError(
+                "The type of 'scheduler' in optimizer.set_lr_scheduler must be LRScheduler, but received %s."
+                % (type(scheduler))
+            )
+        self._learning_rate = scheduler
+
     def get_lr(self):
         """
         Get current learning rate of optimizer.
diff --git a/test/legacy_test/test_imperative_optimizer_v2.py b/test/legacy_test/test_imperative_optimizer_v2.py
index 5348a410e505602ae455c38dd0cb484f578751cb..71f3ac1941fbc4290fc3dae9e755ecdeed0dd843 100644
--- a/test/legacy_test/test_imperative_optimizer_v2.py
+++ b/test/legacy_test/test_imperative_optimizer_v2.py
@@ -656,6 +656,42 @@ class TestOptimizerLearningRate(unittest.TestCase):
             )
             adam.set_lr(0.01)
 
+    def test_set_lr_scheduler(self):
+        with fluid.dygraph.guard():
+            a = np.random.uniform(-0.1, 0.1, [10, 10]).astype("float32")
+
+            linear = paddle.nn.Linear(10, 10)
+
+            a = fluid.dygraph.to_variable(a)
+
+            b = linear(a)
+
+            loss = paddle.mean(b)
+
+            adam = paddle.optimizer.Adam(0.1, parameters=linear.parameters())
+
+            # switch from a float learning rate to an LRScheduler
+            scheduler = paddle.optimizer.lr.StepDecay(
+                learning_rate=0.2, step_size=5, gamma=0.6
+            )
+            adam.set_lr_scheduler(scheduler)
+            adam.minimize(loss)
+            lr = adam.get_lr()
+            np.testing.assert_allclose(lr, 0.2, rtol=1e-06, atol=0.0)
+
+            # switch from one LRScheduler to another LRScheduler
+            scheduler = paddle.optimizer.lr.MultiStepDecay(
+                learning_rate=0.5, milestones=[2, 4, 6], gamma=0.8
+            )
+            adam.set_lr_scheduler(scheduler)
+            adam.minimize(loss)
+            lr = adam.get_lr()
+            np.testing.assert_allclose(lr, 0.5, rtol=1e-06, atol=0.0)
+
+            with self.assertRaises(TypeError):
+                scheduler_var = paddle.fluid.dygraph.StepDecay(0.5, step_size=3)
+                adam.set_lr_scheduler(scheduler_var)
+
 
 class TestImperativeMomentumOptimizer(TestImperativeOptimizerBase):
     def get_optimizer_dygraph(self, parameter_list):
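
For context, below is a minimal sketch of how the set_lr_scheduler API introduced above might be used in a dygraph training loop, including swapping schedulers mid-training. The loop structure, the random input, and the per-epoch scheduler.step() call are illustrative assumptions, not part of this diff; only set_lr_scheduler comes from the change above, and the rest uses existing public Paddle APIs.

import paddle

# build a toy model and an Adam optimizer with a plain float learning rate
linear = paddle.nn.Linear(10, 10)
adam = paddle.optimizer.Adam(0.1, parameters=linear.parameters())

# attach a step-decay schedule using the new API from this diff
scheduler = paddle.optimizer.lr.StepDecay(learning_rate=0.2, step_size=5, gamma=0.6)
adam.set_lr_scheduler(scheduler)

for epoch in range(10):
    x = paddle.rand([4, 10])          # illustrative random batch
    loss = paddle.mean(linear(x))
    loss.backward()
    adam.step()
    adam.clear_grad()
    scheduler.step()                  # advance whichever schedule is currently attached

    if epoch == 4:
        # replace the schedule partway through training
        scheduler = paddle.optimizer.lr.MultiStepDecay(
            learning_rate=0.5, milestones=[2, 4, 6], gamma=0.8
        )
        adam.set_lr_scheduler(scheduler)

    print("epoch {}: lr = {}".format(epoch, adam.get_lr()))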