未验证 提交 0f5949a1 编写于 作者: H HydrogenSulfate 提交者: GitHub

Merge pull request #2308 from HydrogenSulfate/refactor_lr

refactor learning_rate.py
......@@ -15,117 +15,218 @@
from __future__ import (absolute_import, division, print_function,
unicode_literals)
from paddle.optimizer import lr
from paddle.optimizer.lr import LRScheduler
from abc import abstractmethod
from typing import Union
from paddle.optimizer import lr
from ppcls.utils import logger
class Linear(object):
class LRBase(object):
"""Base class for custom learning rates
Args:
epochs (int): total epoch(s)
step_each_epoch (int): number of iterations within an epoch
learning_rate (float): learning rate
warmup_epoch (int): number of warmup epoch(s)
warmup_start_lr (float): start learning rate within warmup
last_epoch (int): last epoch
by_epoch (bool): learning rate decays by epoch when by_epoch is True, else by iter
verbose (bool): If True, prints a message to stdout for each update. Defaults to False
"""
def __init__(self,
epochs: int,
step_each_epoch: int,
learning_rate: float,
warmup_epoch: int,
warmup_start_lr: float,
last_epoch: int,
by_epoch: bool,
verbose: bool=False) -> None:
"""Initialize and record the necessary parameters
"""
super(LRBase, self).__init__()
if warmup_epoch >= epochs:
msg = f"When using warm up, the value of \"Global.epochs\" must be greater than value of \"Optimizer.lr.warmup_epoch\". The value of \"Optimizer.lr.warmup_epoch\" has been set to {epochs}."
logger.warning(msg)
warmup_epoch = epochs
self.epochs = epochs
self.step_each_epoch = step_each_epoch
self.learning_rate = learning_rate
self.warmup_epoch = warmup_epoch
self.warmup_steps = round(
self.warmup_epoch *
self.step_each_epoch) if by_epoch else self.warmup_epoch
self.warmup_start_lr = warmup_start_lr
self.last_epoch = last_epoch
self.by_epoch = by_epoch
self.verbose = verbose
@abstractmethod
def __call__(self, *kargs, **kwargs) -> lr.LRScheduler:
"""generate an learning rate scheduler
Returns:
lr.LinearWarmup: learning rate scheduler
"""
pass
def linear_warmup(
self,
learning_rate: Union[float, lr.LRScheduler]) -> lr.LinearWarmup:
"""Add an Linear Warmup before learning_rate
Args:
learning_rate (Union[float, lr.LRScheduler]): original learning rate without warmup
Returns:
lr.LinearWarmup: learning rate scheduler with warmup
"""
Linear learning rate decay
warmup_lr = lr.LinearWarmup(
learning_rate=learning_rate,
warmup_steps=self.warmup_steps,
start_lr=self.warmup_start_lr,
end_lr=self.learning_rate,
last_epoch=self.last_epoch,
verbose=self.verbose)
return warmup_lr
class Constant(LRBase):
"""Constant learning rate
Args:
lr (float): The initial learning rate. It is a python float number.
epochs(int): The decay step size. It determines the decay cycle.
end_lr(float, optional): The minimum final learning rate. Default: 0.0001.
power(float, optional): Power of polynomial. Default: 1.0.
warmup_epoch(int): The epoch numbers for LinearWarmup. Default: 0.
warmup_start_lr(float): Initial learning rate of warm up. Default: 0.0.
last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
epochs (int): total epoch(s)
step_each_epoch (int): number of iterations within an epoch
learning_rate (float): learning rate
warmup_epoch (int): number of warmup epoch(s)
warmup_start_lr (float): start learning rate within warmup
last_epoch (int): last epoch
by_epoch (bool): learning rate decays by epoch when by_epoch is True, else by iter
"""
def __init__(self,
epochs,
step_each_epoch,
learning_rate,
warmup_epoch=0,
warmup_start_lr=0.0,
last_epoch=-1,
by_epoch=False,
**kwargs):
super(Constant, self).__init__(epochs, step_each_epoch, learning_rate,
warmup_epoch, warmup_start_lr,
last_epoch, by_epoch)
def __call__(self):
learning_rate = lr.LRScheduler(
learning_rate=self.learning_rate, last_epoch=self.last_epoch)
def make_get_lr():
def get_lr(self):
return self.learning_rate
return get_lr
setattr(learning_rate, "get_lr", make_get_lr())
if self.warmup_steps > 0:
learning_rate = self.linear_warmup(learning_rate)
setattr(learning_rate, "by_epoch", self.by_epoch)
return learning_rate
class Linear(LRBase):
"""Linear learning rate decay
Args:
epochs (int): total epoch(s)
step_each_epoch (int): number of iterations within an epoch
learning_rate (float): learning rate
end_lr (float, optional): The minimum final learning rate. Defaults to 0.0.
power (float, optional): Power of polynomial. Defaults to 1.0.
warmup_epoch (int): number of warmup epoch(s)
warmup_start_lr (float): start learning rate within warmup
last_epoch (int): last epoch
by_epoch (bool): learning rate decays by epoch when by_epoch is True, else by iter
"""
def __init__(self,
epochs,
step_each_epoch,
learning_rate,
end_lr=0.0,
power=1.0,
cycle=False,
warmup_epoch=0,
warmup_start_lr=0.0,
last_epoch=-1,
by_epoch=False,
**kwargs):
super().__init__()
if warmup_epoch >= epochs:
msg = f"When using warm up, the value of \"Global.epochs\" must be greater than value of \"Optimizer.lr.warmup_epoch\". The value of \"Optimizer.lr.warmup_epoch\" has been set to {epochs}."
logger.warning(msg)
warmup_epoch = epochs
self.learning_rate = learning_rate
self.steps = (epochs - warmup_epoch) * step_each_epoch
super(Linear, self).__init__(epochs, step_each_epoch, learning_rate,
warmup_epoch, warmup_start_lr, last_epoch,
by_epoch)
self.decay_steps = (epochs - self.warmup_epoch) * step_each_epoch
self.end_lr = end_lr
self.power = power
self.last_epoch = last_epoch
self.warmup_steps = round(warmup_epoch * step_each_epoch)
self.warmup_start_lr = warmup_start_lr
self.cycle = cycle
self.warmup_steps = round(self.warmup_epoch * step_each_epoch)
if self.by_epoch:
self.decay_steps = self.epochs - self.warmup_epoch
def __call__(self):
learning_rate = lr.PolynomialDecay(
learning_rate=self.learning_rate,
decay_steps=self.steps,
decay_steps=self.decay_steps,
end_lr=self.end_lr,
power=self.power,
cycle=self.cycle,
last_epoch=self.
last_epoch) if self.steps > 0 else self.learning_rate
if self.warmup_steps > 0:
learning_rate = lr.LinearWarmup(
learning_rate=learning_rate,
warmup_steps=self.warmup_steps,
start_lr=self.warmup_start_lr,
end_lr=self.learning_rate,
last_epoch=self.last_epoch)
return learning_rate
last_epoch) if self.decay_steps > 0 else self.learning_rate
if self.warmup_steps > 0:
learning_rate = self.linear_warmup(learning_rate)
class Constant(LRScheduler):
"""
Constant learning rate
Args:
lr (float): The initial learning rate. It is a python float number.
last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
"""
setattr(learning_rate, "by_epoch", self.by_epoch)
return learning_rate
def __init__(self, learning_rate, last_epoch=-1, **kwargs):
self.learning_rate = learning_rate
self.last_epoch = last_epoch
super().__init__()
def get_lr(self):
return self.learning_rate
class Cosine(LRBase):
"""Cosine learning rate decay
``lr = 0.05 * (math.cos(epoch * (math.pi / epochs)) + 1)``
class Cosine(object):
"""
Cosine learning rate decay
lr = 0.05 * (math.cos(epoch * (math.pi / epochs)) + 1)
Args:
lr(float): initial learning rate
step_each_epoch(int): steps each epoch
epochs(int): total training epochs
eta_min(float): Minimum learning rate. Default: 0.0.
warmup_epoch(int): The epoch numbers for LinearWarmup. Default: 0.
warmup_start_lr(float): Initial learning rate of warm up. Default: 0.0.
last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
epochs (int): total epoch(s)
step_each_epoch (int): number of iterations within an epoch
learning_rate (float): learning rate
eta_min (float, optional): Minimum learning rate. Defaults to 0.0.
warmup_epoch (int, optional): The epoch numbers for LinearWarmup. Defaults to 0.
warmup_start_lr (float, optional): start learning rate within warmup. Defaults to 0.0.
last_epoch (int, optional): last epoch. Defaults to -1.
by_epoch (bool, optional): learning rate decays by epoch when by_epoch is True, else by iter. Defaults to False.
"""
def __init__(self,
learning_rate,
step_each_epoch,
epochs,
step_each_epoch,
learning_rate,
eta_min=0.0,
warmup_epoch=0,
warmup_start_lr=0.0,
last_epoch=-1,
by_epoch=False,
**kwargs):
super().__init__()
if warmup_epoch >= epochs:
msg = f"When using warm up, the value of \"Global.epochs\" must be greater than value of \"Optimizer.lr.warmup_epoch\". The value of \"Optimizer.lr.warmup_epoch\" has been set to {epochs}."
logger.warning(msg)
warmup_epoch = epochs
self.learning_rate = learning_rate
self.T_max = (epochs - warmup_epoch) * step_each_epoch
super(Cosine, self).__init__(epochs, step_each_epoch, learning_rate,
warmup_epoch, warmup_start_lr, last_epoch,
by_epoch)
self.T_max = (self.epochs - self.warmup_epoch) * self.step_each_epoch
self.eta_min = eta_min
self.last_epoch = last_epoch
self.warmup_steps = round(warmup_epoch * step_each_epoch)
self.warmup_start_lr = warmup_start_lr
if self.by_epoch:
self.T_max = self.epochs - self.warmup_epoch
def __call__(self):
learning_rate = lr.CosineAnnealingDecay(
......@@ -134,51 +235,47 @@ class Cosine(object):
eta_min=self.eta_min,
last_epoch=self.
last_epoch) if self.T_max > 0 else self.learning_rate
if self.warmup_steps > 0:
learning_rate = lr.LinearWarmup(
learning_rate=learning_rate,
warmup_steps=self.warmup_steps,
start_lr=self.warmup_start_lr,
end_lr=self.learning_rate,
last_epoch=self.last_epoch)
learning_rate = self.linear_warmup(learning_rate)
setattr(learning_rate, "by_epoch", self.by_epoch)
return learning_rate
class Step(object):
"""
Piecewise learning rate decay
class Step(LRBase):
"""Step learning rate decay
Args:
step_each_epoch(int): steps each epoch
learning_rate (float): The initial learning rate. It is a python float number.
epochs (int): total epoch(s)
step_each_epoch (int): number of iterations within an epoch
learning_rate (float): learning rate
step_size (int): the interval to update.
gamma (float, optional): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * gamma`` .
It should be less than 1.0. Default: 0.1.
warmup_epoch(int): The epoch numbers for LinearWarmup. Default: 0.
warmup_start_lr(float): Initial learning rate of warm up. Default: 0.0.
last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
gamma (float, optional): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * gamma``. It should be less than 1.0. Default: 0.1.
warmup_epoch (int, optional): The epoch numbers for LinearWarmup. Defaults to 0.
warmup_start_lr (float, optional): start learning rate within warmup. Defaults to 0.0.
last_epoch (int, optional): last epoch. Defaults to -1.
by_epoch (bool, optional): learning rate decays by epoch when by_epoch is True, else by iter. Defaults to False.
"""
def __init__(self,
epochs,
step_each_epoch,
learning_rate,
step_size,
step_each_epoch,
epochs,
gamma,
warmup_epoch=0,
warmup_start_lr=0.0,
last_epoch=-1,
by_epoch=False,
**kwargs):
super().__init__()
if warmup_epoch >= epochs:
msg = f"When using warm up, the value of \"Global.epochs\" must be greater than value of \"Optimizer.lr.warmup_epoch\". The value of \"Optimizer.lr.warmup_epoch\" has been set to {epochs}."
logger.warning(msg)
warmup_epoch = epochs
self.step_size = step_each_epoch * step_size
self.learning_rate = learning_rate
super(Step, self).__init__(epochs, step_each_epoch, learning_rate,
warmup_epoch, warmup_start_lr, last_epoch,
by_epoch)
self.step_size = step_size * step_each_epoch
self.gamma = gamma
self.last_epoch = last_epoch
self.warmup_steps = round(warmup_epoch * step_each_epoch)
self.warmup_start_lr = warmup_start_lr
if self.by_epoch:
self.step_size = step_size
def __call__(self):
learning_rate = lr.StepDecay(
......@@ -186,177 +283,102 @@ class Step(object):
step_size=self.step_size,
gamma=self.gamma,
last_epoch=self.last_epoch)
if self.warmup_steps > 0:
learning_rate = lr.LinearWarmup(
learning_rate=learning_rate,
warmup_steps=self.warmup_steps,
start_lr=self.warmup_start_lr,
end_lr=self.learning_rate,
last_epoch=self.last_epoch)
learning_rate = self.linear_warmup(learning_rate)
setattr(learning_rate, "by_epoch", self.by_epoch)
return learning_rate
class Piecewise(object):
"""
Piecewise learning rate decay
class Piecewise(LRBase):
"""Piecewise learning rate decay
Args:
boundaries(list): A list of steps numbers. The type of element in the list is python int.
values(list): A list of learning rate values that will be picked during different epoch boundaries.
The type of element in the list is python float.
warmup_epoch(int): The epoch numbers for LinearWarmup. Default: 0.
warmup_start_lr(float): Initial learning rate of warm up. Default: 0.0.
by_epoch(bool): Whether lr decay by epoch. Default: False.
last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
epochs (int): total epoch(s)
step_each_epoch (int): number of iterations within an epoch
decay_epochs (List[int]): A list of steps numbers. The type of element in the list is python int.
values (List[float]): A list of learning rate values that will be picked during different epoch boundaries.
warmup_epoch (int, optional): The epoch numbers for LinearWarmup. Defaults to 0.
warmup_start_lr (float, optional): start learning rate within warmup. Defaults to 0.0.
last_epoch (int, optional): last epoch. Defaults to -1.
by_epoch (bool, optional): learning rate decays by epoch when by_epoch is True, else by iter. Defaults to False.
"""
def __init__(self,
epochs,
step_each_epoch,
decay_epochs,
values,
epochs,
warmup_epoch=0,
warmup_start_lr=0.0,
by_epoch=False,
last_epoch=-1,
by_epoch=False,
**kwargs):
super().__init__()
if warmup_epoch >= epochs:
msg = f"When using warm up, the value of \"Global.epochs\" must be greater than value of \"Optimizer.lr.warmup_epoch\". The value of \"Optimizer.lr.warmup_epoch\" has been set to {epochs}."
logger.warning(msg)
warmup_epoch = epochs
self.boundaries_steps = [step_each_epoch * e for e in decay_epochs]
self.boundaries_epoch = decay_epochs
super(Piecewise,
self).__init__(epochs, step_each_epoch, values[0], warmup_epoch,
warmup_start_lr, last_epoch, by_epoch)
self.values = values
self.last_epoch = last_epoch
self.warmup_steps = round(warmup_epoch * step_each_epoch)
self.warmup_epoch = warmup_epoch
self.warmup_start_lr = warmup_start_lr
self.by_epoch = by_epoch
self.boundaries_steps = [e * step_each_epoch for e in decay_epochs]
if self.by_epoch is True:
self.boundaries_steps = decay_epochs
def __call__(self):
if self.by_epoch:
learning_rate = lr.PiecewiseDecay(
boundaries=self.boundaries_epoch,
values=self.values,
last_epoch=self.last_epoch)
if self.warmup_epoch > 0:
learning_rate = lr.LinearWarmup(
learning_rate=learning_rate,
warmup_steps=self.warmup_epoch,
start_lr=self.warmup_start_lr,
end_lr=self.values[0],
last_epoch=self.last_epoch)
else:
learning_rate = lr.PiecewiseDecay(
boundaries=self.boundaries_steps,
values=self.values,
last_epoch=self.last_epoch)
if self.warmup_steps > 0:
learning_rate = lr.LinearWarmup(
learning_rate=learning_rate,
warmup_steps=self.warmup_steps,
start_lr=self.warmup_start_lr,
end_lr=self.values[0],
last_epoch=self.last_epoch)
learning_rate = self.linear_warmup(learning_rate)
setattr(learning_rate, "by_epoch", self.by_epoch)
return learning_rate
class MultiStepDecay(LRScheduler):
"""
Update the learning rate by ``gamma`` once ``epoch`` reaches one of the milestones.
The algorithm can be described as the code below.
.. code-block:: text
learning_rate = 0.5
milestones = [30, 50]
gamma = 0.1
if epoch < 30:
learning_rate = 0.5
elif epoch < 50:
learning_rate = 0.05
else:
learning_rate = 0.005
Args:
learning_rate (float): The initial learning rate. It is a python float number.
milestones (tuple|list): List or tuple of each boundaries. Must be increasing.
gamma (float, optional): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * gamma`` .
It should be less than 1.0. Default: 0.1.
last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .
class MultiStepDecay(LRBase):
"""MultiStepDecay learning rate decay
Returns:
``MultiStepDecay`` instance to schedule learning rate.
Examples:
.. code-block:: python
import paddle
import numpy as np
# train on default dynamic graph mode
linear = paddle.nn.Linear(10, 10)
scheduler = paddle.optimizer.lr.MultiStepDecay(learning_rate=0.5, milestones=[2, 4, 6], gamma=0.8, verbose=True)
sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
for epoch in range(20):
for batch_id in range(5):
x = paddle.uniform([10, 10])
out = linear(x)
loss = paddle.mean(out)
loss.backward()
sgd.step()
sgd.clear_gradients()
scheduler.step() # If you update learning rate each step
# scheduler.step() # If you update learning rate each epoch
# train on static graph mode
paddle.enable_static()
main_prog = paddle.static.Program()
start_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, start_prog):
x = paddle.static.data(name='x', shape=[None, 4, 5])
y = paddle.static.data(name='y', shape=[None, 4, 5])
z = paddle.static.nn.fc(x, 100)
loss = paddle.mean(z)
scheduler = paddle.optimizer.lr.MultiStepDecay(learning_rate=0.5, milestones=[2, 4, 6], gamma=0.8, verbose=True)
sgd = paddle.optimizer.SGD(learning_rate=scheduler)
sgd.minimize(loss)
exe = paddle.static.Executor()
exe.run(start_prog)
for epoch in range(20):
for batch_id in range(5):
out = exe.run(
main_prog,
feed={
'x': np.random.randn(3, 4, 5).astype('float32'),
'y': np.random.randn(3, 4, 5).astype('float32')
},
fetch_list=loss.name)
scheduler.step() # If you update learning rate each step
# scheduler.step() # If you update learning rate each epoch
Args:
epochs (int): total epoch(s)
step_each_epoch (int): number of iterations within an epoch
learning_rate (float): learning rate
milestones (List[int]): List of each boundaries. Must be increasing.
gamma (float, optional): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * gamma``. It should be less than 1.0. Defaults to 0.1.
warmup_epoch (int, optional): The epoch numbers for LinearWarmup. Defaults to 0.
warmup_start_lr (float, optional): start learning rate within warmup. Defaults to 0.0.
last_epoch (int, optional): last epoch. Defaults to -1.
by_epoch (bool, optional): learning rate decays by epoch when by_epoch is True, else by iter. Defaults to False.
"""
def __init__(self,
learning_rate,
milestones,
epochs,
step_each_epoch,
learning_rate,
milestones,
gamma=0.1,
warmup_epoch=0,
warmup_start_lr=0.0,
last_epoch=-1,
verbose=False):
if not isinstance(milestones, (tuple, list)):
raise TypeError(
"The type of 'milestones' in 'MultiStepDecay' must be 'tuple, list', but received %s."
% type(milestones))
if not all([
milestones[i] < milestones[i + 1]
for i in range(len(milestones) - 1)
]):
raise ValueError('The elements of milestones must be incremented')
if gamma >= 1.0:
raise ValueError('gamma should be < 1.0.')
by_epoch=False,
**kwargs):
super(MultiStepDecay, self).__init__(
epochs, step_each_epoch, learning_rate, warmup_epoch,
warmup_start_lr, last_epoch, by_epoch)
self.milestones = [x * step_each_epoch for x in milestones]
self.gamma = gamma
super().__init__(learning_rate, last_epoch, verbose)
if self.by_epoch:
self.milestones = milestones
def get_lr(self):
for i in range(len(self.milestones)):
if self.last_epoch < self.milestones[i]:
return self.base_lr * (self.gamma**i)
return self.base_lr * (self.gamma**len(self.milestones))
def __call__(self):
learning_rate = lr.MultiStepDecay(
learning_rate=self.learning_rate,
milestones=self.milestones,
gamma=self.gamma,
last_epoch=self.last_epoch)
if self.warmup_steps > 0:
learning_rate = self.linear_warmup(learning_rate)
setattr(learning_rate, "by_epoch", self.by_epoch)
return learning_rate
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册