未验证 提交 8ffebb5a 编写于 作者: A Asthestarsfalll 提交者: GitHub

【PaddlePaddle Hackathon 2】12、为 Paddle 新增 OneCycleLR 优化调度器 (#41825)

* add OneCycleLR

* add missing total_steps

* try

* update

* fix conflict bug

* fix typo

* update doc

* update and polish

* Refactor

* update

* change end_lr to end_learning_rate
上级 4b355ff9
......@@ -321,6 +321,70 @@ def step_lr(epoch_num, learning_rate, step_size, gamma=0.1, verbose=False):
return learning_rate * math.pow(gamma, epoch_num // step_size)
def one_cycle_lr(epoch_num,
max_learning_rate,
total_steps,
divide_factor=25,
end_learning_rate=0.0001,
phase_pct=0.3,
anneal_strategy='cos',
three_phase=False,
verbose=False):
initial_lr = max_learning_rate / divide_factor
if three_phase:
_end_steps = [
float(phase_pct * total_steps) - 1,
float(2 * phase_pct * total_steps) - 2, total_steps - 1
]
_schedule_phases = [
{
'start_lr': initial_lr,
'end_lr': max_learning_rate,
},
{
'start_lr': max_learning_rate,
'end_lr': initial_lr,
},
{
'start_lr': initial_lr,
'end_lr': end_learning_rate,
},
]
else:
_end_steps = [float(phase_pct * total_steps) - 1, total_steps - 1]
_schedule_phases = [
{
'start_lr': initial_lr,
'end_lr': max_learning_rate,
},
{
'start_lr': max_learning_rate,
'end_lr': end_learning_rate,
},
]
if anneal_strategy == 'cos':
def anneal_func(start, end, pct):
cos_out = math.cos(math.pi * pct) + 1
return end + (start - end) / 2.0 * cos_out
else:
def anneal_func(start, end, pct):
return (end - start) * pct + start
start_step = 0
for i, phase in enumerate(_schedule_phases):
end_step = _end_steps[i]
if epoch_num <= end_step or i == len(_schedule_phases) - 1:
pct = (epoch_num - start_step) / (end_step - start_step)
computed_lr = anneal_func(phase['start_lr'], phase['end_lr'], pct)
break
start_step = end_step
return computed_lr
class TestLRScheduler(unittest.TestCase):
def _test_static(self, python_func, paddle_api, kwarg, place):
scheduler = paddle_api(**kwarg)
......@@ -467,6 +531,33 @@ class TestLRScheduler(unittest.TestCase):
with self.assertRaises(ValueError):
paddle.optimizer.lr.MultiStepDecay(
learning_rate=0.5, milestones=[1, 2, 3], gamma=2)
with self.assertRaises(TypeError):
paddle.optimizer.lr.OneCycleLR(
max_learning_rate='test', total_steps=20)
with self.assertRaises(ValueError):
paddle.optimizer.lr.OneCycleLR(
max_learning_rate=-1.5, total_steps=20)
with self.assertRaises(TypeError):
paddle.optimizer.lr.OneCycleLR(
max_learning_rate=0.1, total_steps=20, end_learning_rate='test')
with self.assertRaises(ValueError):
paddle.optimizer.lr.OneCycleLR(
max_learning_rate=0.1, total_steps=20, end_learning_rate=-1)
with self.assertRaises(TypeError):
paddle.optimizer.lr.OneCycleLR(
max_learning_rate=0.1, total_steps='test')
with self.assertRaises(ValueError):
paddle.optimizer.lr.OneCycleLR(
max_learning_rate=0.1, total_steps=-10)
with self.assertRaises(ValueError):
paddle.optimizer.lr.OneCycleLR(
max_learning_rate=0.1, total_steps=20, anneal_strategy='test')
with self.assertRaises(ValueError):
paddle.optimizer.lr.OneCycleLR(
max_learning_rate=0.1,
total_steps=20,
phase_pct=0.6,
three_phase=True)
func_api_kwargs = [(noam_lr, paddle.optimizer.lr.NoamDecay, {
"d_model": 0.01,
......@@ -527,6 +618,38 @@ class TestLRScheduler(unittest.TestCase):
"learning_rate": 0.5,
"T_max": 10,
"verbose": False
}), (one_cycle_lr, paddle.optimizer.lr.OneCycleLR, {
"max_learning_rate": 0.1,
"total_steps": 20,
"divide_factor": 5,
"end_learning_rate": 0.0001,
"anneal_strategy": 'cos',
"phase_pct": 0.3,
"three_phase": False,
}), (one_cycle_lr, paddle.optimizer.lr.OneCycleLR, {
"max_learning_rate": 0.5,
"total_steps": 20,
"divide_factor": 10,
"end_learning_rate": 0.001,
"anneal_strategy": 'linear',
"phase_pct": 0.4,
"three_phase": False,
}), (one_cycle_lr, paddle.optimizer.lr.OneCycleLR, {
"max_learning_rate": 1.0,
"total_steps": 20,
"divide_factor": 9,
"end_learning_rate": 0.0001,
"anneal_strategy": 'cos',
"phase_pct": 0.3,
"three_phase": True,
}), (one_cycle_lr, paddle.optimizer.lr.OneCycleLR, {
"max_learning_rate": 0.3,
"total_steps": 20,
"divide_factor": 25,
"end_learning_rate": 0.0005,
"anneal_strategy": 'linear',
"phase_pct": 0.2,
"three_phase": True,
})]
for python_func, paddle_api, kwarg in func_api_kwargs:
......
......@@ -33,7 +33,8 @@ __all__ = [ # noqa
'LambdaDecay',
'ReduceOnPlateau',
'CosineAnnealingDecay',
'MultiplicativeDecay'
'MultiplicativeDecay',
'OneCycleLR'
]
......@@ -1591,3 +1592,212 @@ class MultiplicativeDecay(LRScheduler):
for epoch in range(1, self.last_epoch + 1):
cur_lr = cur_lr * self.lr_lambda(epoch)
return cur_lr
class OneCycleLR(LRScheduler):
r"""
Sets the learning rate according to the one cycle learning rate scheduler.
The scheduler adjusts the learning rate from an initial learning rate to the maximum learning rate and then
from that maximum learning rate to the minimum learning rate, which is much less than the initial learning rate.
It has been proposed in `Super-Convergence: Very Fast Training of Neural Networks Using Large Learning Rates <https://arxiv.org/abs/1708.07120>`_.
Please note that the default behaviour of this scheduler follows the fastai implementation of one cycle,
which claims that “unpublished work has shown even better results by using only two phases”.
If you want the behaviour of this scheduler to be consistent with the paper, please set ``three_phase=True`` .
Also note that you should update learning rate each step.
Args:
max_learning_rate (float): The maximum learning rate. It is a python float number.
Functionally, it defines the initial learning rate by ``divide_factor`` .
total_steps (int): Number of total training steps.
divide_factor (float): Initial learning rate will be determined by initial_learning_rate = max_learning_rate / divide_factor. Default: 25.
end_learning_rate (float, optional): The minimum learning rate during training, it should be much less than initial learning rate.
phase_pct (float): The percentage of total steps which used to increasing learning rate. Default: 0.3.
anneal_strategy (str, optional): Strategy of adjusting learning rate.'cos' for cosine annealing,
'linear' for linear annealing. Default: 'cos'.
three_phase (bool, optional): Whether to use three phase.
If ``True``:
1. The learning rate will first increase from initial learning rate to maximum learning rate.
2. Then it will decrease to initial learning rate. Number of step in this phase is the same as the one in first phase.
3. Finally, it will decrease to minimum learning rate which is much less than initial learning rate.
If ``False``:
1. The learning rate will increase to maximum learning rate.
2. Then it will directly decrease to minimum learning rate.
last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .
Returns:
``OneCycleLR`` instance to schedule learning rate.
Examples:
.. code-block:: python
import paddle
import numpy as np
# train on default dynamic graph mode
linear = paddle.nn.Linear(10, 10)
scheduler = paddle.optimizer.lr.OneCycleLR(max_learning_rate=1.0, total_steps=100, verbose=True)
sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
for epoch in range(5):
for batch_id in range(20):
x = paddle.uniform([10, 10])
out = linear(x)
loss = paddle.mean(out)
loss.backward()
sgd.step()
sgd.clear_gradients()
scheduler.step() # You should update learning rate each step
# train on static graph mode
paddle.enable_static()
main_prog = paddle.static.Program()
start_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, start_prog):
x = paddle.static.data(name='x', shape=[None, 4, 5])
y = paddle.static.data(name='y', shape=[None, 4, 5])
z = paddle.static.nn.fc(x, 100)
loss = paddle.mean(z)
scheduler = paddle.optimizer.lr.OneCycleLR(max_learning_rate=1.0, total_steps=100, verbose=True)
sgd = paddle.optimizer.SGD(learning_rate=scheduler)
sgd.minimize(loss)
exe = paddle.static.Executor()
exe.run(start_prog)
for epoch in range(5):
for batch_id in range(20):
out = exe.run(
main_prog,
feed={
'x': np.random.randn(3, 4, 5).astype('float32'),
'y': np.random.randn(3, 4, 5).astype('float32')
},
fetch_list=loss.name)
scheduler.step() # You should update learning rate each step
"""
def __init__(self,
max_learning_rate,
total_steps,
divide_factor=25.,
end_learning_rate=0.0001,
phase_pct=0.3,
anneal_strategy='cos',
three_phase=False,
last_epoch=-1,
verbose=False):
# Check type and value of max_learning_rate
if not isinstance(max_learning_rate, (float, int)):
raise TypeError(
"'max_learning_rate' must be 'float' or 'int', but received {}".
format(type(total_steps)))
if max_learning_rate < 0:
raise ValueError("'max_learning_rate' must be a positive integer.")
# Check type and value of end_learning_rate
if not isinstance(end_learning_rate, (float, int)):
raise TypeError(
"'end_learning_rate' must be 'float' or 'int', but received {}".
format(type(total_steps)))
if end_learning_rate < 0:
raise ValueError("'end_learning_rate' must be a positive integer.")
# Check type and value of total_steps
if not isinstance(total_steps, int):
raise TypeError("'total_step' must be 'int', but received {}".
format(type(total_steps)))
if total_steps <= 0:
raise ValueError("'total_step' must be a positive integer.")
self.total_steps = total_steps
# Check type and value of pac_start
if not isinstance(phase_pct, float):
raise TypeError("'phase_pct' must be 'float', but received {}".
format(type(phase_pct)))
if phase_pct < 0 or phase_pct > 1:
raise ValueError(
"'phase_pct' must be between 0 and 1, but received {}".format(
phase_pct))
# Check type and value of divide_factor
if not isinstance(divide_factor, (float, int)):
raise TypeError(
"'divide_factor' must be 'float' or 'int', but received {}".
format(type(divide_factor)))
initial_lr = max_learning_rate / float(divide_factor)
min_lr = float(end_learning_rate)
if three_phase:
if phase_pct >= 0.5:
raise ValueError(
"When three_phase is True, 'phase_pct' must be less than 0.5"
)
# start step and end step of each phase.
self._step_config = [
0,
phase_pct * self.total_steps - 1,
2 * phase_pct * self.total_steps - 2,
self.total_steps - 1,
self.total_steps - 1, # for the last step.
]
# step size of each phase.
self._steps_size = [
self._step_config[1] - self._step_config[0],
self._step_config[2] - self._step_config[1],
self._step_config[3] - self._step_config[2],
self._step_config[3] -
self._step_config[2], # for the last step.
]
# start lr and end lr of each phase.
self._lr_config = [
initial_lr, max_learning_rate, initial_lr, min_lr
]
else:
self._step_config = [
0, phase_pct * self.total_steps - 1, self.total_steps - 1,
self.total_steps - 1
]
self._steps_size = [
self._step_config[1] - self._step_config[0],
self._step_config[2] - self._step_config[1],
self._step_config[2] - self._step_config[1],
]
self._lr_config = [initial_lr, max_learning_rate, min_lr]
# Check anneal_strategy
if anneal_strategy == 'cos':
self.anneal_func = self._cos_annealing
elif anneal_strategy == 'linear':
self.anneal_func = self._linear_annealing
else:
raise ValueError(
"'anneal_strategy' must by one of 'cos' or 'linear', but received {}".
format(anneal_strategy))
super(OneCycleLR, self).__init__(initial_lr, last_epoch, verbose)
def _cos_annealing(self, start_lr, end_lr, pct):
cos_out = math.cos(math.pi * pct) + 1
return end_lr + (start_lr - end_lr) / 2.0 * cos_out
def _linear_annealing(self, start_lr, end_lr, pct):
return (end_lr - start_lr) * pct + start_lr
def get_lr(self):
current_step = self.last_epoch
if current_step > self.total_steps:
raise ValueError(
"Tried to step {} times. However the number of total steps is {}"
.format(current_step, self.total_steps))
for (i, (end_step, step_size)
) in enumerate(zip(self._step_config[1:], self._steps_size)):
# i == len(self._lr_config) - 2 catch the last step, otherwise it will return None.
if current_step <= end_step or i == len(self._lr_config) - 2:
# self._step_config[i] means start step of a phase.
percentage = (current_step - self._step_config[i]) / step_size
return self.anneal_func(self._lr_config[i],
self._lr_config[i + 1], percentage)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册