未验证 提交 b63e0ccb 编写于 作者: Z Zhou Wei 提交者: GitHub

fix load check_point bug of LinearWarmup (#28280)

上级 0b678d40
......@@ -284,11 +284,19 @@ def linear_warmup_lr(epoch_num,
start_lr,
end_lr,
verbose=False):
if epoch_num < warmup_steps:
tmp = epoch_num - warmup_steps
if tmp < 0:
return start_lr + (end_lr - start_lr) * (float(epoch_num) /
float(warmup_steps))
elif paddle.in_dynamic_mode():
if tmp < 3:
return 0.5
elif tmp < 6:
return 0.2
else:
return 0.1
else:
return learning_rate
return 0.5
def multi_step_lr(epoch_num,
......@@ -407,6 +415,9 @@ class TestLRScheduler(unittest.TestCase):
paddle.disable_static(place)
x = np.random.uniform(-1, 1, [10, 10]).astype("float32")
linear = paddle.nn.Linear(10, 10)
if paddle_api.__name__ == "LinearWarmup":
kwarg['learning_rate'] = paddle.optimizer.lr.PiecewiseDecay(
[3, 6], [0.5, 0.2, 0.1])
scheduler = paddle_api(**kwarg)
adam = paddle.optimizer.Adam(
learning_rate=scheduler, parameters=linear.parameters())
......@@ -420,12 +431,26 @@ class TestLRScheduler(unittest.TestCase):
adam.clear_grad()
current_lr = adam.get_lr()
expected_lr = python_func(epoch, **kwarg)
if paddle_api.__name__ != "CosineAnnealingDecay":
self.assertEqual(current_lr, expected_lr)
scheduler.step()
else:
if paddle_api.__name__ == "CosineAnnealingDecay":
self.assertAlmostEqual(current_lr, expected_lr)
scheduler.step(epoch + 1)
elif paddle_api.__name__ == "LinearWarmup":
self.assertAlmostEqual(current_lr, expected_lr)
state_dict = adam.state_dict()
scheduler1 = paddle.optimizer.lr.LinearWarmup(**kwarg)
adam1 = paddle.optimizer.Adam(
learning_rate=scheduler1, parameters=linear.parameters())
adam1.set_state_dict(state_dict)
self.assertEqual(scheduler.last_epoch, scheduler1.last_epoch)
self.assertEqual(scheduler.last_lr, scheduler1.last_lr)
self.assertEqual(scheduler.learning_rate.last_lr,
scheduler1.learning_rate.last_lr)
self.assertEqual(scheduler.learning_rate.last_epoch,
scheduler1.learning_rate.last_epoch)
scheduler.step()
else:
self.assertEqual(current_lr, expected_lr)
scheduler.step()
def test_scheduler(self):
with self.assertRaises(NotImplementedError):
......@@ -464,8 +489,7 @@ class TestLRScheduler(unittest.TestCase):
"decay_steps": 20,
"end_lr": 0,
"power": 1.0,
"cycle": False,
"verbose": True
"cycle": False
}), (polynomial_lr, paddle.optimizer.lr.PolynomialDecay, {
"learning_rate": 0.5,
"decay_steps": 20,
......@@ -475,10 +499,9 @@ class TestLRScheduler(unittest.TestCase):
"verbose": False
}), (linear_warmup_lr, paddle.optimizer.lr.LinearWarmup, {
'learning_rate': 0.5,
'warmup_steps': 20,
'warmup_steps': 10,
'start_lr': 0,
'end_lr': 0.5,
"verbose": True
'end_lr': 0.5
}), (exponential_lr, paddle.optimizer.lr.ExponentialDecay, {
"learning_rate": 0.5,
"gamma": 0.9,
......@@ -486,8 +509,7 @@ class TestLRScheduler(unittest.TestCase):
}), (multi_step_lr, paddle.optimizer.lr.MultiStepDecay, {
"learning_rate": 0.5,
"milestones": [3, 6, 9, 15, 20],
"gamma": 0.8,
"verbose": True
"gamma": 0.8
}), (step_lr, paddle.optimizer.lr.StepDecay, {
"learning_rate": 0.5,
"step_size": 2,
......@@ -510,7 +532,7 @@ class TestLRScheduler(unittest.TestCase):
for place in places:
paddle.enable_static()
#self._test_static(python_func, paddle_api, kwarg, place)
self._test_static(python_func, paddle_api, kwarg, place)
paddle.disable_static(place)
self._test_dygraph(python_func, paddle_api, kwarg, place)
paddle.enable_static()
......
......@@ -365,7 +365,6 @@ class PiecewiseDecay(LRScheduler):
last_epoch=last_epoch, verbose=verbose)
def get_lr(self):
for i in range(len(self.boundaries)):
if self.last_epoch < self.boundaries[i]:
return self.values[i]
......@@ -750,14 +749,34 @@ class LinearWarmup(LRScheduler):
end_lr, start_lr)
super(LinearWarmup, self).__init__(start_lr, last_epoch, verbose)
def state_dict(self):
"""
Returns the state of the LinearWarmup scheduler as a :class:`dict`.
It is a subset of ``self.__dict__`` .
"""
state_dict = super(LinearWarmup, self).state_dict()
if isinstance(self.learning_rate, LRScheduler):
state_dict["LinearWarmup_LR"] = self.learning_rate.state_dict()
return state_dict
def set_state_dict(self, state_dict):
"""
Loads state_dict for LinearWarmup scheduler.
"""
super(LinearWarmup, self).set_state_dict(state_dict)
if isinstance(self.learning_rate, LRScheduler):
self.learning_rate.set_state_dict(state_dict["LinearWarmup_LR"])
def get_lr(self):
if self.last_epoch < self.warmup_steps:
return (self.end_lr - self.start_lr) * float(
self.last_epoch) / float(self.warmup_steps) + self.start_lr
else:
if isinstance(self.learning_rate, LRScheduler):
lr_value = self.learning_rate()
self.learning_rate.step()
return self.learning_rate()
return lr_value
return self.learning_rate
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册