Unverified commit 511e204e authored by Zhou Wei, committed by GitHub

LRScheduler.get_lr should not update lr in LinearWarmup (#31843)

Parent 6472d620
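This commit removes a side effect from LinearWarmup.get_lr(): previously, once past the warmup phase, every call to get_lr() also stepped the wrapped scheduler, so merely querying the rate (e.g., for logging) silently advanced it. After the fix, get_lr() synchronizes the wrapped scheduler to last_epoch - warmup_steps and reads it, so repeated calls are idempotent. A minimal sketch of the fixed behavior, assuming a PaddlePaddle 2.x install where paddle.optimizer.lr is available:

    import paddle

    inner = paddle.optimizer.lr.NaturalExpDecay(learning_rate=0.5, gamma=0.1)
    warmup = paddle.optimizer.lr.LinearWarmup(
        learning_rate=inner, warmup_steps=10, start_lr=0.0, end_lr=0.1)

    # Move past the warmup phase.
    for _ in range(15):
        warmup.step()

    # After this commit get_lr() is read-only: repeated calls return the
    # same value. Before the fix, each call stepped the wrapped
    # NaturalExpDecay as a side effect and changed the result.
    assert warmup.get_lr() == warmup.get_lr()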
@@ -537,6 +537,18 @@ class TestLRScheduler(unittest.TestCase):
         self._test_dygraph(python_func, paddle_api, kwarg, place)
         paddle.enable_static()

+    def test_linear_warmp(self):
+        natural_lr = paddle.optimizer.lr.NaturalExpDecay(
+            learning_rate=0.5, gamma=0.1)
+        natural_lr_warmup = paddle.optimizer.lr.LinearWarmup(
+            learning_rate=natural_lr, warmup_steps=10, start_lr=0.0, end_lr=0.1)
+        for idx in range(30):
+            if idx >= 10:
+                self.assertEqual(natural_lr_warmup.get_lr(),
+                                 natural_lr.get_lr())
+            natural_lr.step()
+            natural_lr_warmup.step()
+

 if __name__ == '__main__':
     unittest.main()
@@ -786,9 +786,8 @@ class LinearWarmup(LRScheduler):
                 self.last_epoch) / float(self.warmup_steps) + self.start_lr
         else:
             if isinstance(self.learning_rate, LRScheduler):
-                lr_value = self.learning_rate()
-                self.learning_rate.step()
-                return lr_value
+                self.learning_rate.step(self.last_epoch - self.warmup_steps)
+                return self.learning_rate()

             return self.learning_rate
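For reference, the post-fix control flow of LinearWarmup.get_lr() reduces to the standalone sketch below. ToyExpDecay is hypothetical, standing in for any wrapped LRScheduler; only the control flow mirrors the diff above.

    import math

    class ToyExpDecay:
        # Hypothetical stand-in for a wrapped scheduler such as NaturalExpDecay.
        def __init__(self, base_lr=0.5, gamma=0.1):
            self.base_lr, self.gamma, self.last_epoch = base_lr, gamma, 0

        def step(self, epoch=None):
            # With an explicit epoch, step() synchronizes rather than advances.
            self.last_epoch = self.last_epoch + 1 if epoch is None else epoch

        def __call__(self):
            return self.base_lr * math.exp(-self.gamma * self.last_epoch)

    def linear_warmup_get_lr(inner, last_epoch, warmup_steps, start_lr, end_lr):
        if last_epoch < warmup_steps:
            # Linear interpolation from start_lr to end_lr during warmup.
            return (end_lr - start_lr) * last_epoch / warmup_steps + start_lr
        # After warmup: sync the wrapped scheduler to its own epoch count,
        # then read it. No hidden advance, so repeated calls are idempotent.
        inner.step(last_epoch - warmup_steps)
        return inner()

    inner = ToyExpDecay()
    lr1 = linear_warmup_get_lr(inner, 15, 10, 0.0, 0.1)
    lr2 = linear_warmup_get_lr(inner, 15, 10, 0.0, 0.1)
    assert lr1 == lr2  # read-only: same epoch, same value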