From 511e204e620f3c6e3df2018746c52c5bf2386a59 Mon Sep 17 00:00:00 2001 From: Zhou Wei <52485244+zhouwei25@users.noreply.github.com> Date: Thu, 25 Mar 2021 11:24:01 +0800 Subject: [PATCH] LRScheduler.get_lr should not update lr in LinearWarmup (#31843) --- .../fluid/tests/unittests/test_lr_scheduler.py | 12 ++++++++++++ python/paddle/optimizer/lr.py | 5 ++--- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_lr_scheduler.py b/python/paddle/fluid/tests/unittests/test_lr_scheduler.py index 8c6383cd6ef..04a0d47e47c 100644 --- a/python/paddle/fluid/tests/unittests/test_lr_scheduler.py +++ b/python/paddle/fluid/tests/unittests/test_lr_scheduler.py @@ -537,6 +537,18 @@ class TestLRScheduler(unittest.TestCase): self._test_dygraph(python_func, paddle_api, kwarg, place) paddle.enable_static() + def test_linear_warmp(self): + natural_lr = paddle.optimizer.lr.NaturalExpDecay( + learning_rate=0.5, gamma=0.1) + natural_lr_warmup = paddle.optimizer.lr.LinearWarmup( + learning_rate=natural_lr, warmup_steps=10, start_lr=0.0, end_lr=0.1) + for idx in range(30): + if idx >= 10: + self.assertEqual(natural_lr_warmup.get_lr(), + natural_lr.get_lr()) + natural_lr.step() + natural_lr_warmup.step() + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/optimizer/lr.py b/python/paddle/optimizer/lr.py index 5085911ce92..484b4fb7246 100644 --- a/python/paddle/optimizer/lr.py +++ b/python/paddle/optimizer/lr.py @@ -786,9 +786,8 @@ class LinearWarmup(LRScheduler): self.last_epoch) / float(self.warmup_steps) + self.start_lr else: if isinstance(self.learning_rate, LRScheduler): - lr_value = self.learning_rate() - self.learning_rate.step() - return lr_value + self.learning_rate.step(self.last_epoch - self.warmup_steps) + return self.learning_rate() return self.learning_rate -- GitLab