未验证 提交 33949fc5 编写于 作者: A Asthestarsfalll 提交者: GitHub

【Hackathon No.13】为 Paddle 新增 CyclicLR 优化调度器 (#40698)

* add paddle.optimizer.lr.CyclicLR

* add unittest of CyclicLR

* fix code format

* fix bug

* try

* fix CI-Coverage

* fix ValueError

* fix arguments assgin

* fix code format and retry pulling develop to pass ci

* fix typo

* Refactor

* fix function-redefined in test_lr_scheduler.py

* update

* fix conflict

* update

* gamma->exp_gamma

* polish docs

* fix code-style

* adjust code format again

* change format of __all__ in lr.py
上级 cad139a7
......@@ -389,6 +389,53 @@ def one_cycle_lr(epoch_num,
return computed_lr
def cyclic_lr(epoch_num,
base_learning_rate,
max_learning_rate,
step_size_up,
step_size_down,
mode,
exp_gamma=0.1,
scale_fn=None,
scale_mode='cycle',
verbose=False):
total_steps = step_size_up + step_size_down
step_ratio = step_size_up / total_steps
def triangular(x):
return 1.
def triangular2(x):
return 1 / (2.**(x - 1))
def exp_range(x):
return exp_gamma**x
if scale_fn is None:
if mode == 'triangular':
scale_fn = triangular
scale_mode = 'cycle'
elif mode == 'triangular2':
scale_fn = triangular2
scale_mode = 'cycle'
elif mode == 'exp_range':
scale_fn = exp_range
scale_mode = 'iterations'
cycle = math.floor(1 + epoch_num / total_steps)
iterations = epoch_num
x = 1. + epoch_num / total_steps - cycle
if x <= step_ratio:
scale_factor = x / step_ratio
else:
scale_factor = (x - 1) / (step_ratio - 1)
base_height = (max_learning_rate - base_learning_rate) * scale_factor
return base_learning_rate + base_height * scale_fn(eval(scale_mode))
class TestLRScheduler(unittest.TestCase):
def _test_static(self, python_func, paddle_api, kwarg, place):
......@@ -533,35 +580,89 @@ class TestLRScheduler(unittest.TestCase):
paddle.optimizer.lr.MultiStepDecay(learning_rate=0.5,
milestones=[1, 2, 3],
gamma=2)
# check type of max_learning_rate
with self.assertRaises(TypeError):
paddle.optimizer.lr.OneCycleLR(max_learning_rate='test',
total_steps=20)
# check value of max_learning_rate
with self.assertRaises(ValueError):
paddle.optimizer.lr.OneCycleLR(max_learning_rate=-1.5,
total_steps=20)
# check type of end_learning_rate
with self.assertRaises(TypeError):
paddle.optimizer.lr.OneCycleLR(max_learning_rate=0.1,
total_steps=20,
end_learning_rate='test')
# check value of end_learning_rate
with self.assertRaises(ValueError):
paddle.optimizer.lr.OneCycleLR(max_learning_rate=0.1,
total_steps=20,
end_learning_rate=-1)
# check type of total_steps
with self.assertRaises(TypeError):
paddle.optimizer.lr.OneCycleLR(max_learning_rate=0.1,
total_steps='test')
# check value of total_steps
with self.assertRaises(ValueError):
paddle.optimizer.lr.OneCycleLR(max_learning_rate=0.1,
total_steps=-10)
# check value of anneal_strategy
with self.assertRaises(ValueError):
paddle.optimizer.lr.OneCycleLR(max_learning_rate=0.1,
total_steps=20,
anneal_strategy='test')
# check value of phase_pct when three_phase is True
with self.assertRaises(ValueError):
paddle.optimizer.lr.OneCycleLR(max_learning_rate=0.1,
total_steps=20,
phase_pct=0.6,
three_phase=True)
# check type of max_learning_rate
with self.assertRaises(TypeError):
paddle.optimizer.lr.CyclicLR(base_learning_rate=0.5,
max_learning_rate='test',
step_size_up=10)
# check value of max_learning_rate
with self.assertRaises(ValueError):
paddle.optimizer.lr.CyclicLR(base_learning_rate=0.5,
max_learning_rate=-1,
step_size_up=10)
# check type of step_size_up
with self.assertRaises(TypeError):
paddle.optimizer.lr.CyclicLR(base_learning_rate=0.5,
max_learning_rate=1.0,
step_size_up='test')
# check value of step_size_up
with self.assertRaises(ValueError):
paddle.optimizer.lr.CyclicLR(base_learning_rate=0.5,
max_learning_rate=1.0,
step_size_up=-1)
# check type of step_size_down
with self.assertRaises(TypeError):
paddle.optimizer.lr.CyclicLR(base_learning_rate=0.5,
max_learning_rate=1.0,
step_size_up=500,
step_size_down='test')
# check type of step_size_down
with self.assertRaises(ValueError):
paddle.optimizer.lr.CyclicLR(base_learning_rate=0.5,
max_learning_rate=1.0,
step_size_up=500,
step_size_down=-1)
# check value of mode
with self.assertRaises(ValueError):
paddle.optimizer.lr.CyclicLR(base_learning_rate=0.5,
max_learning_rate=1.0,
step_size_up=500,
step_size_down=500,
mode='test')
# check type value of scale_mode
with self.assertRaises(ValueError):
paddle.optimizer.lr.CyclicLR(base_learning_rate=0.5,
max_learning_rate=1.0,
step_size_up=500,
step_size_down=-1,
scale_mode='test')
func_api_kwargs = [
(noam_lr, paddle.optimizer.lr.NoamDecay, {
......@@ -671,6 +772,61 @@ class TestLRScheduler(unittest.TestCase):
"anneal_strategy": 'linear',
"phase_pct": 0.2,
"three_phase": True,
}),
(cyclic_lr, paddle.optimizer.lr.CyclicLR, {
"base_learning_rate": 0.5,
"max_learning_rate": 1.0,
"step_size_up": 15,
"step_size_down": 5,
"mode": 'triangular',
"exp_gamma": 1.,
"scale_fn": None,
"scale_mode": 'cycle',
"verbose": False
}),
(cyclic_lr, paddle.optimizer.lr.CyclicLR, {
"base_learning_rate": 0.5,
"max_learning_rate": 1.0,
"step_size_up": 15,
"step_size_down": 5,
"mode": 'triangular2',
"exp_gamma": 1.,
"scale_fn": None,
"scale_mode": 'cycle',
"verbose": False
}),
(cyclic_lr, paddle.optimizer.lr.CyclicLR, {
"base_learning_rate": 0.5,
"max_learning_rate": 1.0,
"step_size_up": 15,
"step_size_down": 5,
"mode": 'exp_range',
"exp_gamma": 0.8,
"scale_fn": None,
"scale_mode": 'cycle',
"verbose": False
}),
(cyclic_lr, paddle.optimizer.lr.CyclicLR, {
"base_learning_rate": 0.5,
"max_learning_rate": 1.0,
"step_size_up": 15,
"step_size_down": 5,
"mode": 'exp_range',
"exp_gamma": 1.,
"scale_fn": lambda x: 0.95**x,
"scale_mode": 'cycle',
"verbose": False
}),
(cyclic_lr, paddle.optimizer.lr.CyclicLR, {
"base_learning_rate": 0.5,
"max_learning_rate": 1.0,
"step_size_up": 15,
"step_size_down": 5,
"mode": 'exp_range',
"exp_gamma": 1.,
"scale_fn": lambda x: 0.95,
"scale_mode": 'iterations',
"verbose": False
})
]
......
......@@ -20,10 +20,22 @@ import paddle.fluid.core as core
from ..fluid.framework import _in_legacy_dygraph
__all__ = [ # noqa
'LRScheduler', 'NoamDecay', 'PiecewiseDecay', 'NaturalExpDecay',
'InverseTimeDecay', 'PolynomialDecay', 'LinearWarmup', 'ExponentialDecay',
'MultiStepDecay', 'StepDecay', 'LambdaDecay', 'ReduceOnPlateau',
'CosineAnnealingDecay', 'MultiplicativeDecay', 'OneCycleLR'
'LRScheduler',
'NoamDecay',
'PiecewiseDecay',
'NaturalExpDecay',
'InverseTimeDecay',
'PolynomialDecay',
'LinearWarmup',
'ExponentialDecay',
'MultiStepDecay',
'StepDecay',
'LambdaDecay',
'ReduceOnPlateau',
'CosineAnnealingDecay',
'MultiplicativeDecay',
'OneCycleLR',
'CyclicLR',
]
......@@ -1681,7 +1693,7 @@ class OneCycleLR(LRScheduler):
if not isinstance(max_learning_rate, (float, int)):
raise TypeError(
"'max_learning_rate' must be 'float' or 'int', but received {}".
format(type(total_steps)))
format(type(max_learning_rate)))
if max_learning_rate < 0:
raise ValueError("'max_learning_rate' must be a positive integer.")
......@@ -1689,7 +1701,7 @@ class OneCycleLR(LRScheduler):
if not isinstance(end_learning_rate, (float, int)):
raise TypeError(
"'end_learning_rate' must be 'float' or 'int', but received {}".
format(type(total_steps)))
format(type(end_learning_rate)))
if end_learning_rate < 0:
raise ValueError("'end_learning_rate' must be a positive integer.")
......@@ -1792,3 +1804,205 @@ class OneCycleLR(LRScheduler):
percentage = (current_step - self._step_config[i]) / step_size
return self.anneal_func(self._lr_config[i],
self._lr_config[i + 1], percentage)
class CyclicLR(LRScheduler):
r"""
Set the learning rate according to the cyclic learning rate (CLR) scheduler.
The scheduler regards the process of learning rate adjustment as one cycle after another.
It cycles the learning rate between two boundaries with a constant frequency.
The distance between the two boundaries can be scaled on a per-iteration or per-cycle basis.
It has been proposed in `Cyclic Learning Rates for Training Neural Networks <https://arxiv.org/abs/1506.01186>`_.
According to the paper, the cyclic learning rate schedule has three build-in scale methods:
* "triangular": A basic triangular cycle without any amplitude scaling.
* "triangular2": A basic triangular cycle that reduce initial amplitude by half each cycle.
* "exp_range": A cycle that scales initial amplitude by scale function which is defined as :math:`gamma^{iterations}` .
The initial amplitude is defined as max_learning_rate - base_learning_rate.
Also note that you should update learning rate each step.
Args:
base_learning_rate (float): Initial learning rate, which is the lower boundary in the cycle. The paper recommends
that set the base_learning_rate to 1/3 or 1/4 of max_learning_rate.
max_learning_rate (float): Maximum learning rate in the cycle. It defines the cycle amplitude as above.
Since there is some scaling operation during process of learning rate adjustment,
max_learning_rate may not actually be reached.
step_size_up (int): Number of training steps, which is used to increase learning rate in a cycle.
The step size of one cycle will be defined by step_size_up + step_size_down. According to the paper, step
size should be set as at least 3 or 4 times steps in one epoch.
step_size_down (int, optional): Number of training steps, which is used to decrease learning rate in a cycle.
If not specified, it's value will initialize to `` step_size_up `` . Default: None
mode (str, optional): one of 'triangular', 'triangular2' or 'exp_range'.
If scale_fn is specified, this argument will be ignored. Default: 'triangular'
exp_gamma (float): Constant in 'exp_range' scaling function: exp_gamma**iterations. Used only when mode = 'exp_range'. Default: 1.0
scale_fn (function, optional): A custom scaling function, which is used to replace three build-in methods.
It should only have one argument. For all x >= 0, 0 <= scale_fn(x) <= 1.
If specified, then 'mode' will be ignored. Default: None
scale_mode (str, optional): One of 'cycle' or 'iterations'. Defines whether scale_fn is evaluated on cycle
number or cycle iterations (total iterations since start of training). Default: 'cycle'
last_epoch (int, optional): The index of last epoch. Can be set to restart training.Default: -1, means initial learning rate.
verbose: (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .
Returns:
``CyclicLR`` instance to schedule learning rate.
Examples:
.. code-block:: python
import paddle
import numpy as np
# train on default dynamic graph mode
linear = paddle.nn.Linear(10, 10)
scheduler = paddle.optimizer.lr.CyclicLR(base_learning_rate=0.5, max_learning_rate=1.0, step_size_up=15, step_size_down=5, verbose=True)
sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
for epoch in range(5):
for batch_id in range(20):
x = paddle.uniform([10, 10])
out = linear(x)
loss = paddle.mean(out)
loss.backward()
sgd.step()
sgd.clear_gradients()
scheduler.step() # You should update learning rate each step
# train on static graph mode
paddle.enable_static()
main_prog = paddle.static.Program()
start_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, start_prog):
x = paddle.static.data(name='x', shape=[None, 4, 5])
y = paddle.static.data(name='y', shape=[None, 4, 5])
z = paddle.static.nn.fc(x, 100)
loss = paddle.mean(z)
scheduler = paddle.optimizer.lr.CyclicLR(base_learning_rate=0.5,
max_learning_rate=1.0, step_size_up=15, step_size_down=5, verbose=True)
sgd = paddle.optimizer.SGD(learning_rate=scheduler)
sgd.minimize(loss)
exe = paddle.static.Executor()
exe.run(start_prog)
for epoch in range(5):
for batch_id in range(20):
out = exe.run(
main_prog,
feed={
'x': np.random.randn(3, 4, 5).astype('float32'),
'y': np.random.randn(3, 4, 5).astype('float32')
},
fetch_list=loss.name)
scheduler.step() # You should update learning rate each step
"""
def __init__(self,
base_learning_rate,
max_learning_rate,
step_size_up,
step_size_down=None,
mode='triangular',
exp_gamma=1.,
scale_fn=None,
scale_mode='cycle',
last_epoch=-1,
verbose=False):
# check type and value of max_learning_rate
if not isinstance(max_learning_rate, (float, int)):
raise TypeError(
"'max_learning_rate' must be 'float' or 'int', but received {}".
format(type(max_learning_rate)))
if max_learning_rate < 0:
raise ValueError(
"'max_learning_rate' must be a positive integer, but received {}"
.format(max_learning_rate))
# check type and value of step_size_up
if not isinstance(step_size_up, int):
raise TypeError(
"The type of 'step_size_up' must be int, but received {}".
format(type(step_size_up)))
if step_size_up <= 0:
raise ValueError(
"'step_size_up' must be a positive integer, but received {}".
format(step_size_up))
# check type and value of step_size_down
if step_size_down is not None:
if not isinstance(step_size_down, int):
raise TypeError(
"The type of 'step_size_down' must be int, but received {}".
format(type(step_size_down)))
if step_size_down <= 0:
raise ValueError(
"'step_size_down' must be a positive integer, but received {}"
.format(step_size_down))
# check type of exp_gamma
if not isinstance(exp_gamma, float):
raise TypeError(
"The type of 'exp_gamma' must be float, but received {}".format(
type(exp_gamma)))
step_size_up = float(step_size_up)
step_size_down = float(
step_size_down) if step_size_down is not None else step_size_up
self.cycle_size = step_size_up + step_size_down
self.step_up_pct = step_size_up / self.cycle_size
self.max_lr = float(max_learning_rate)
self.amplitude = self.max_lr - base_learning_rate
if mode not in ['triangular', 'triangular2', 'exp_range'
] and scale_fn is None:
raise ValueError(
"'mode' is invalid and 'scale_fn' is not specified, make sure one of 'mode' or 'scale_fn' is valid"
)
if scale_mode not in ['cycle', 'iterations']:
raise ValueError(
"'scale_mode' must be one of 'cycle' or 'iterations")
self.mode = mode
self.gamma = exp_gamma # only for exp_range mode
if scale_fn is None:
if self.mode == 'triangular':
self.scale_fn = self._triangular_scale_fn
self.scale_mode = 'cycle'
elif self.mode == 'triangular2':
self.scale_fn = self._triangular2_scale_fn
self.scale_mode = 'cycle'
elif self.mode == 'exp_range':
self.scale_fn = self._exp_range_scale_fn
self.scale_mode = 'iterations'
else:
self.scale_fn = scale_fn
self.scale_mode = scale_mode
super().__init__(base_learning_rate, last_epoch, verbose)
def _triangular_scale_fn(self, x):
return 1.
def _triangular2_scale_fn(self, x):
return 1 / (2.**(x - 1))
def _exp_range_scale_fn(self, x):
return self.gamma**x
def get_lr(self):
iterations = self.last_epoch
cycle = 1 + iterations // self.cycle_size
pct_per_cycle = 1. + iterations / self.cycle_size - cycle
if pct_per_cycle <= self.step_up_pct:
scale_factor = pct_per_cycle / self.step_up_pct
else:
scale_factor = (1 - pct_per_cycle) / (1 - self.step_up_pct)
base_height = self.amplitude * scale_factor
lr = self.base_lr + base_height * self.scale_fn(eval(self.scale_mode))
return lr
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册