diff --git a/python/paddle/fluid/tests/unittests/test_lr_scheduler.py b/python/paddle/fluid/tests/unittests/test_lr_scheduler.py index 8c6383cd6ef5230e350cc68bf269a641f7014339..04a0d47e47c86ba82712a162e8ef133418bf0886 100644 --- a/python/paddle/fluid/tests/unittests/test_lr_scheduler.py +++ b/python/paddle/fluid/tests/unittests/test_lr_scheduler.py @@ -537,6 +537,18 @@ class TestLRScheduler(unittest.TestCase): self._test_dygraph(python_func, paddle_api, kwarg, place) paddle.enable_static() + def test_linear_warmup(self): + natural_lr = paddle.optimizer.lr.NaturalExpDecay( + learning_rate=0.5, gamma=0.1) + natural_lr_warmup = paddle.optimizer.lr.LinearWarmup( + learning_rate=natural_lr, warmup_steps=10, start_lr=0.0, end_lr=0.1) + for idx in range(30): + if idx >= 10: + self.assertEqual(natural_lr_warmup.get_lr(), + natural_lr.get_lr()) + natural_lr.step() + natural_lr_warmup.step() + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/optimizer/lr.py b/python/paddle/optimizer/lr.py index 5085911ce927a319c191b91cd6b48af64a50a05d..484b4fb7246a766ea4ee984980b0fc0ffecabf4e 100644 --- a/python/paddle/optimizer/lr.py +++ b/python/paddle/optimizer/lr.py @@ -786,9 +786,8 @@ class LinearWarmup(LRScheduler): self.last_epoch) / float(self.warmup_steps) + self.start_lr else: if isinstance(self.learning_rate, LRScheduler): - lr_value = self.learning_rate() - self.learning_rate.step() - return lr_value + self.learning_rate.step(self.last_epoch - self.warmup_steps) + return self.learning_rate() return self.learning_rate