Unverified commit b63e0ccb authored by Zhou Wei, committed by GitHub

fix load check_point bug of LinearWarmup (#28280)

Parent 0b678d40
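
For context, a minimal sketch of the checkpoint round trip this commit repairs (illustrative only, not part of the commit: the variable names adam2/scheduler2 and the step count are assumptions, while the scheduler arguments mirror the updated unit test). Judging from the diff, LinearWarmup previously did not serialize the state of a wrapped LRScheduler, so the inner scheduler effectively restarted after a checkpoint was loaded.

    import paddle

    # Build an optimizer whose LinearWarmup wraps a PiecewiseDecay, as in the updated test.
    linear = paddle.nn.Linear(10, 10)
    scheduler = paddle.optimizer.lr.LinearWarmup(
        learning_rate=paddle.optimizer.lr.PiecewiseDecay([3, 6], [0.5, 0.2, 0.1]),
        warmup_steps=10, start_lr=0, end_lr=0.5)
    adam = paddle.optimizer.Adam(learning_rate=scheduler, parameters=linear.parameters())

    for _ in range(15):              # advance past the warmup phase
        scheduler.step()

    state = adam.state_dict()        # with this fix, also carries the wrapped scheduler's state

    # Restore into a fresh optimizer/scheduler pair and check that the state round-trips.
    scheduler2 = paddle.optimizer.lr.LinearWarmup(
        learning_rate=paddle.optimizer.lr.PiecewiseDecay([3, 6], [0.5, 0.2, 0.1]),
        warmup_steps=10, start_lr=0, end_lr=0.5)
    adam2 = paddle.optimizer.Adam(learning_rate=scheduler2, parameters=linear.parameters())
    adam2.set_state_dict(state)

    assert scheduler2.last_epoch == scheduler.last_epoch
    assert scheduler2.learning_rate.last_epoch == scheduler.learning_rate.last_epoch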
@@ -284,11 +284,19 @@ def linear_warmup_lr(epoch_num,
                      start_lr,
                      end_lr,
                      verbose=False):
-    if epoch_num < warmup_steps:
+    tmp = epoch_num - warmup_steps
+    if tmp < 0:
         return start_lr + (end_lr - start_lr) * (float(epoch_num) /
                                                  float(warmup_steps))
+    elif paddle.in_dynamic_mode():
+        if tmp < 3:
+            return 0.5
+        elif tmp < 6:
+            return 0.2
+        else:
+            return 0.1
     else:
-        return learning_rate
+        return 0.5


 def multi_step_lr(epoch_num,
@@ -407,6 +415,9 @@ class TestLRScheduler(unittest.TestCase):
         paddle.disable_static(place)
         x = np.random.uniform(-1, 1, [10, 10]).astype("float32")
         linear = paddle.nn.Linear(10, 10)
+        if paddle_api.__name__ == "LinearWarmup":
+            kwarg['learning_rate'] = paddle.optimizer.lr.PiecewiseDecay(
+                [3, 6], [0.5, 0.2, 0.1])
         scheduler = paddle_api(**kwarg)
         adam = paddle.optimizer.Adam(
             learning_rate=scheduler, parameters=linear.parameters())
@@ -420,12 +431,26 @@ class TestLRScheduler(unittest.TestCase):
             adam.clear_grad()
             current_lr = adam.get_lr()
             expected_lr = python_func(epoch, **kwarg)
-            if paddle_api.__name__ != "CosineAnnealingDecay":
-                self.assertEqual(current_lr, expected_lr)
-                scheduler.step()
-            else:
+            if paddle_api.__name__ == "CosineAnnealingDecay":
                 self.assertAlmostEqual(current_lr, expected_lr)
                 scheduler.step(epoch + 1)
+            elif paddle_api.__name__ == "LinearWarmup":
+                self.assertAlmostEqual(current_lr, expected_lr)
+                state_dict = adam.state_dict()
+                scheduler1 = paddle.optimizer.lr.LinearWarmup(**kwarg)
+                adam1 = paddle.optimizer.Adam(
+                    learning_rate=scheduler1, parameters=linear.parameters())
+                adam1.set_state_dict(state_dict)
+                self.assertEqual(scheduler.last_epoch, scheduler1.last_epoch)
+                self.assertEqual(scheduler.last_lr, scheduler1.last_lr)
+                self.assertEqual(scheduler.learning_rate.last_lr,
+                                 scheduler1.learning_rate.last_lr)
+                self.assertEqual(scheduler.learning_rate.last_epoch,
+                                 scheduler1.learning_rate.last_epoch)
+                scheduler.step()
+            else:
+                self.assertEqual(current_lr, expected_lr)
+                scheduler.step()

     def test_scheduler(self):
         with self.assertRaises(NotImplementedError):
@@ -464,8 +489,7 @@ class TestLRScheduler(unittest.TestCase):
             "decay_steps": 20,
             "end_lr": 0,
             "power": 1.0,
-            "cycle": False,
-            "verbose": True
+            "cycle": False
         }), (polynomial_lr, paddle.optimizer.lr.PolynomialDecay, {
             "learning_rate": 0.5,
             "decay_steps": 20,
@@ -475,10 +499,9 @@ class TestLRScheduler(unittest.TestCase):
             "verbose": False
         }), (linear_warmup_lr, paddle.optimizer.lr.LinearWarmup, {
             'learning_rate': 0.5,
-            'warmup_steps': 20,
+            'warmup_steps': 10,
             'start_lr': 0,
-            'end_lr': 0.5,
-            "verbose": True
+            'end_lr': 0.5
         }), (exponential_lr, paddle.optimizer.lr.ExponentialDecay, {
             "learning_rate": 0.5,
             "gamma": 0.9,
@@ -486,8 +509,7 @@ class TestLRScheduler(unittest.TestCase):
         }), (multi_step_lr, paddle.optimizer.lr.MultiStepDecay, {
             "learning_rate": 0.5,
             "milestones": [3, 6, 9, 15, 20],
-            "gamma": 0.8,
-            "verbose": True
+            "gamma": 0.8
         }), (step_lr, paddle.optimizer.lr.StepDecay, {
             "learning_rate": 0.5,
             "step_size": 2,
@@ -510,7 +532,7 @@ class TestLRScheduler(unittest.TestCase):
         for place in places:
             paddle.enable_static()
-            #self._test_static(python_func, paddle_api, kwarg, place)
+            self._test_static(python_func, paddle_api, kwarg, place)
             paddle.disable_static(place)
             self._test_dygraph(python_func, paddle_api, kwarg, place)
             paddle.enable_static()
...
@@ -365,7 +365,6 @@ class PiecewiseDecay(LRScheduler):
             last_epoch=last_epoch, verbose=verbose)

     def get_lr(self):
-
         for i in range(len(self.boundaries)):
             if self.last_epoch < self.boundaries[i]:
                 return self.values[i]
@@ -750,14 +749,34 @@ class LinearWarmup(LRScheduler):
             end_lr, start_lr)
         super(LinearWarmup, self).__init__(start_lr, last_epoch, verbose)

+    def state_dict(self):
+        """
+        Returns the state of the LinearWarmup scheduler as a :class:`dict`.
+
+        It is a subset of ``self.__dict__`` .
+        """
+        state_dict = super(LinearWarmup, self).state_dict()
+        if isinstance(self.learning_rate, LRScheduler):
+            state_dict["LinearWarmup_LR"] = self.learning_rate.state_dict()
+        return state_dict
+
+    def set_state_dict(self, state_dict):
+        """
+        Loads state_dict for LinearWarmup scheduler.
+        """
+        super(LinearWarmup, self).set_state_dict(state_dict)
+        if isinstance(self.learning_rate, LRScheduler):
+            self.learning_rate.set_state_dict(state_dict["LinearWarmup_LR"])
+
     def get_lr(self):
         if self.last_epoch < self.warmup_steps:
             return (self.end_lr - self.start_lr) * float(
                 self.last_epoch) / float(self.warmup_steps) + self.start_lr
         else:
             if isinstance(self.learning_rate, LRScheduler):
+                lr_value = self.learning_rate()
                 self.learning_rate.step()
-                return self.learning_rate()
+                return lr_value

             return self.learning_rate
...
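
To summarize the lr.py side of the change (a sketch, not code from the commit; the wrapped PiecewiseDecay and its arguments are taken from the updated test, and the variable names are assumptions): state_dict() now nests the wrapped scheduler's state under the "LinearWarmup_LR" key and set_state_dict() restores it, while get_lr() now returns the wrapped scheduler's current value before stepping it, which is what the updated reference function linear_warmup_lr in the test expects.

    import paddle

    warmup = paddle.optimizer.lr.LinearWarmup(
        learning_rate=paddle.optimizer.lr.PiecewiseDecay([3, 6], [0.5, 0.2, 0.1]),
        warmup_steps=10, start_lr=0, end_lr=0.5)
    warmup.step()

    state = warmup.state_dict()
    print("LinearWarmup_LR" in state)   # True: the wrapped scheduler's state is nested here

    fresh = paddle.optimizer.lr.LinearWarmup(
        learning_rate=paddle.optimizer.lr.PiecewiseDecay([3, 6], [0.5, 0.2, 0.1]),
        warmup_steps=10, start_lr=0, end_lr=0.5)
    fresh.set_state_dict(state)         # restores the wrapped PiecewiseDecay as well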