From b63e0ccb4a029784b38b9cb2d0d963250c0c0fda Mon Sep 17 00:00:00 2001
From: Zhou Wei <52485244+zhouwei25@users.noreply.github.com>
Date: Wed, 28 Oct 2020 11:07:03 +0800
Subject: [PATCH] fix load check_point bug of LinearWarmup (#28280)

---
 .../tests/unittests/test_lr_scheduler.py | 50 +++++++++++++------
 python/paddle/optimizer/lr.py            | 23 ++++++++-
 2 files changed, 57 insertions(+), 16 deletions(-)

diff --git a/python/paddle/fluid/tests/unittests/test_lr_scheduler.py b/python/paddle/fluid/tests/unittests/test_lr_scheduler.py
index 0cdc413c2f..8c6383cd6e 100644
--- a/python/paddle/fluid/tests/unittests/test_lr_scheduler.py
+++ b/python/paddle/fluid/tests/unittests/test_lr_scheduler.py
@@ -284,11 +284,19 @@ def linear_warmup_lr(epoch_num,
                      start_lr,
                      end_lr,
                      verbose=False):
-    if epoch_num < warmup_steps:
+    tmp = epoch_num - warmup_steps
+    if tmp < 0:
         return start_lr + (end_lr - start_lr) * (float(epoch_num) /
                                                  float(warmup_steps))
+    elif paddle.in_dynamic_mode():
+        if tmp < 3:
+            return 0.5
+        elif tmp < 6:
+            return 0.2
+        else:
+            return 0.1
     else:
-        return learning_rate
+        return 0.5
 
 
 def multi_step_lr(epoch_num,
@@ -407,6 +415,9 @@ class TestLRScheduler(unittest.TestCase):
         paddle.disable_static(place)
         x = np.random.uniform(-1, 1, [10, 10]).astype("float32")
         linear = paddle.nn.Linear(10, 10)
+        if paddle_api.__name__ == "LinearWarmup":
+            kwarg['learning_rate'] = paddle.optimizer.lr.PiecewiseDecay(
+                [3, 6], [0.5, 0.2, 0.1])
         scheduler = paddle_api(**kwarg)
         adam = paddle.optimizer.Adam(
             learning_rate=scheduler, parameters=linear.parameters())
@@ -420,12 +431,26 @@ class TestLRScheduler(unittest.TestCase):
             adam.clear_grad()
             current_lr = adam.get_lr()
             expected_lr = python_func(epoch, **kwarg)
-            if paddle_api.__name__ != "CosineAnnealingDecay":
-                self.assertEqual(current_lr, expected_lr)
-                scheduler.step()
-            else:
+            if paddle_api.__name__ == "CosineAnnealingDecay":
                 self.assertAlmostEqual(current_lr, expected_lr)
                 scheduler.step(epoch + 1)
+            elif paddle_api.__name__ == "LinearWarmup":
+                self.assertAlmostEqual(current_lr, expected_lr)
+                state_dict = adam.state_dict()
+                scheduler1 = paddle.optimizer.lr.LinearWarmup(**kwarg)
+                adam1 = paddle.optimizer.Adam(
+                    learning_rate=scheduler1, parameters=linear.parameters())
+                adam1.set_state_dict(state_dict)
+                self.assertEqual(scheduler.last_epoch, scheduler1.last_epoch)
+                self.assertEqual(scheduler.last_lr, scheduler1.last_lr)
+                self.assertEqual(scheduler.learning_rate.last_lr,
+                                 scheduler1.learning_rate.last_lr)
+                self.assertEqual(scheduler.learning_rate.last_epoch,
+                                 scheduler1.learning_rate.last_epoch)
+                scheduler.step()
+            else:
+                self.assertEqual(current_lr, expected_lr)
+                scheduler.step()
 
     def test_scheduler(self):
         with self.assertRaises(NotImplementedError):
@@ -464,8 +489,7 @@ class TestLRScheduler(unittest.TestCase):
             "decay_steps": 20,
             "end_lr": 0,
             "power": 1.0,
-            "cycle": False,
-            "verbose": True
+            "cycle": False
         }), (polynomial_lr, paddle.optimizer.lr.PolynomialDecay, {
             "learning_rate": 0.5,
             "decay_steps": 20,
@@ -475,10 +499,9 @@ class TestLRScheduler(unittest.TestCase):
             "verbose": False
         }), (linear_warmup_lr, paddle.optimizer.lr.LinearWarmup, {
             'learning_rate': 0.5,
-            'warmup_steps': 20,
+            'warmup_steps': 10,
             'start_lr': 0,
-            'end_lr': 0.5,
-            "verbose": True
+            'end_lr': 0.5
         }), (exponential_lr, paddle.optimizer.lr.ExponentialDecay, {
             "learning_rate": 0.5,
             "gamma": 0.9,
@@ -486,8 +509,7 @@ class TestLRScheduler(unittest.TestCase):
         }), (multi_step_lr, paddle.optimizer.lr.MultiStepDecay, {
             "learning_rate": 0.5,
             "milestones": [3, 6, 9, 15, 20],
-            "gamma": 0.8,
-            "verbose": True
+            "gamma": 0.8
         }), (step_lr, paddle.optimizer.lr.StepDecay, {
             "learning_rate": 0.5,
             "step_size": 2,
@@ -510,7 +532,7 @@ class TestLRScheduler(unittest.TestCase):
 
         for place in places:
             paddle.enable_static()
-            #self._test_static(python_func, paddle_api, kwarg, place)
+            self._test_static(python_func, paddle_api, kwarg, place)
             paddle.disable_static(place)
             self._test_dygraph(python_func, paddle_api, kwarg, place)
             paddle.enable_static()
diff --git a/python/paddle/optimizer/lr.py b/python/paddle/optimizer/lr.py
index 051d3cf18f..80b4b2a9d0 100644
--- a/python/paddle/optimizer/lr.py
+++ b/python/paddle/optimizer/lr.py
@@ -365,7 +365,6 @@ class PiecewiseDecay(LRScheduler):
             last_epoch=last_epoch,
             verbose=verbose)
 
     def get_lr(self):
-
         for i in range(len(self.boundaries)):
             if self.last_epoch < self.boundaries[i]:
                 return self.values[i]
@@ -750,14 +749,34 @@ class LinearWarmup(LRScheduler):
             end_lr, start_lr)
         super(LinearWarmup, self).__init__(start_lr, last_epoch, verbose)
 
+    def state_dict(self):
+        """
+        Returns the state of the LinearWarmup scheduler as a :class:`dict`.
+
+        It is a subset of ``self.__dict__`` .
+        """
+        state_dict = super(LinearWarmup, self).state_dict()
+        if isinstance(self.learning_rate, LRScheduler):
+            state_dict["LinearWarmup_LR"] = self.learning_rate.state_dict()
+        return state_dict
+
+    def set_state_dict(self, state_dict):
+        """
+        Loads state_dict for LinearWarmup scheduler.
+        """
+        super(LinearWarmup, self).set_state_dict(state_dict)
+        if isinstance(self.learning_rate, LRScheduler):
+            self.learning_rate.set_state_dict(state_dict["LinearWarmup_LR"])
+
     def get_lr(self):
         if self.last_epoch < self.warmup_steps:
             return (self.end_lr - self.start_lr) * float(
                 self.last_epoch) / float(self.warmup_steps) + self.start_lr
         else:
             if isinstance(self.learning_rate, LRScheduler):
+                lr_value = self.learning_rate()
                 self.learning_rate.step()
-                return self.learning_rate()
+                return lr_value
             return self.learning_rate
--
GitLab
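
Note: the following is a minimal dygraph sketch of the scenario this patch fixes, written against the patched paddle.optimizer.lr API and mirroring the LinearWarmup branch of _test_dygraph above; it is an illustration, not part of the patch. Before the fix, LinearWarmup.state_dict() omitted the wrapped scheduler's state, so a LinearWarmup wrapping a PiecewiseDecay restarted the inner schedule from scratch after a checkpoint reload.

    import numpy as np
    import paddle

    paddle.disable_static()  # dygraph mode, as in the unit test


    def make_scheduler():
        # LinearWarmup wrapping PiecewiseDecay, same kwargs as the unit test.
        return paddle.optimizer.lr.LinearWarmup(
            learning_rate=paddle.optimizer.lr.PiecewiseDecay(
                boundaries=[3, 6], values=[0.5, 0.2, 0.1]),
            warmup_steps=10,
            start_lr=0,
            end_lr=0.5)


    linear = paddle.nn.Linear(10, 10)
    scheduler = make_scheduler()
    adam = paddle.optimizer.Adam(
        learning_rate=scheduler, parameters=linear.parameters())

    # Train past the warmup phase so the wrapped PiecewiseDecay has stepped.
    for epoch in range(15):
        x = paddle.to_tensor(
            np.random.uniform(-1, 1, [10, 10]).astype("float32"))
        loss = paddle.mean(linear(x))
        loss.backward()
        adam.step()
        adam.clear_grad()
        scheduler.step()

    # With this patch, the scheduler state saved through the optimizer also
    # carries the wrapped scheduler's state under the "LinearWarmup_LR" key.
    state_dict = adam.state_dict()

    # Restore into fresh objects; the inner PiecewiseDecay resumes as well,
    # instead of silently restarting from epoch 0 as it did before the fix.
    scheduler1 = make_scheduler()
    adam1 = paddle.optimizer.Adam(
        learning_rate=scheduler1, parameters=linear.parameters())
    adam1.set_state_dict(state_dict)

    assert scheduler1.last_epoch == scheduler.last_epoch
    assert scheduler1.learning_rate.last_epoch == \
        scheduler.learning_rate.last_epoch

The get_lr() change is the companion fix: once past warmup, the wrapped scheduler's current value is read before step() advances it, so the reported learning rate corresponds to the current epoch rather than the one after it.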