Commit 94433634 authored by gaotingquan, committed by Tingquan Gao

fix: fix Linear & support warmup start lr & support Cosine eta_min

Support setting the warmup start lr for all schedulers, and eta_min for Cosine.
Fix the bug that Linear cannot decay to end_lr when warmup is set.
Parent 48494ec0
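The Linear fix comes down to the decay step count: LinearWarmup consumes the first warmup_steps calls to step() before it starts driving the wrapped scheduler, so a PolynomialDecay sized for the full run never reaches end_lr. A minimal sketch of the old versus new behaviour, with made-up step counts and learning rates (none of these numbers come from a real config):

import paddle

# Illustrative numbers only; not taken from any PaddleClas config.
total_steps, warmup_steps = 100, 20

def build(decay_steps):
    # PolynomialDecay wrapped in LinearWarmup, mirroring the Linear wrapper below.
    decay = paddle.optimizer.lr.PolynomialDecay(
        learning_rate=0.1, decay_steps=decay_steps, end_lr=0.0, power=1.0)
    return paddle.optimizer.lr.LinearWarmup(
        learning_rate=decay, warmup_steps=warmup_steps, start_lr=0.0, end_lr=0.1)

old = build(total_steps)                 # before: decay sized for the whole run
new = build(total_steps - warmup_steps)  # after: decay sized for the post-warmup steps

for sched in (old, new):
    for _ in range(total_steps):
        sched.step()
    print(sched.last_lr)  # old stops above end_lr (about 0.02 here); new reaches 0.0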
@@ -26,6 +26,8 @@ class Linear(object):
epochs(int): The total number of training epochs, which determines the decay cycle.
end_lr(float, optional): The minimum final learning rate. Default: 0.0001.
power(float, optional): Power of polynomial. Default: 1.0.
warmup_epoch(int): The number of warmup epochs for LinearWarmup. Default: 0.
warmup_start_lr(float): The initial learning rate of the warmup phase. Default: 0.0.
last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
"""
@@ -36,28 +38,30 @@ class Linear(object):
end_lr=0.0,
power=1.0,
warmup_epoch=0,
warmup_start_lr=0.0,
last_epoch=-1,
**kwargs):
super(Linear, self).__init__()
self.learning_rate = learning_rate
self.epochs = epochs * step_each_epoch
self.steps = (epochs - warmup_epoch) * step_each_epoch
self.end_lr = end_lr
self.power = power
self.last_epoch = last_epoch
self.warmup_epoch = round(warmup_epoch * step_each_epoch)
self.warmup_steps = round(warmup_epoch * step_each_epoch)
self.warmup_start_lr = warmup_start_lr
def __call__(self):
learning_rate = lr.PolynomialDecay(
learning_rate=self.learning_rate,
decay_steps=self.epochs,
decay_steps=self.steps,
end_lr=self.end_lr,
power=self.power,
last_epoch=self.last_epoch)
if self.warmup_epoch > 0:
if self.warmup_steps > 0:
learning_rate = lr.LinearWarmup(
learning_rate=learning_rate,
warmup_steps=self.warmup_epoch,
start_lr=0.0,
warmup_steps=self.warmup_steps,
start_lr=self.warmup_start_lr,
end_lr=self.learning_rate,
last_epoch=self.last_epoch)
return learning_rate
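A hypothetical way to exercise the updated Linear wrapper, assuming the class above is importable; the hyperparameters are made up for illustration:

# Hypothetical usage; all values are illustrative.
sched = Linear(learning_rate=0.1, epochs=10, step_each_epoch=100,
               end_lr=0.0, power=1.0, warmup_epoch=1, warmup_start_lr=0.0)()
for _ in range(10 * 100):
    sched.step()
print(sched.last_lr)  # with the fix this ends at end_lr even though warmup is enabled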
@@ -71,6 +75,9 @@ class Cosine(object):
learning_rate(float): Initial learning rate.
step_each_epoch(int): Number of steps in each epoch.
epochs(int): Total training epochs.
eta_min(float): Minimum learning rate. Default: 0.0.
warmup_epoch(int): The number of warmup epochs for LinearWarmup. Default: 0.
warmup_start_lr(float): The initial learning rate of the warmup phase. Default: 0.0.
last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
"""
@@ -78,25 +85,30 @@ class Cosine(object):
learning_rate,
step_each_epoch,
epochs,
eta_min=0.0,
warmup_epoch=0,
warmup_start_lr=0.0,
last_epoch=-1,
**kwargs):
super(Cosine, self).__init__()
self.learning_rate = learning_rate
self.T_max = step_each_epoch * epochs
self.T_max = (epochs - warmup_epoch) * step_each_epoch
self.eta_min = eta_min
self.last_epoch = last_epoch
self.warmup_epoch = round(warmup_epoch * step_each_epoch)
self.warmup_steps = round(warmup_epoch * step_each_epoch)
self.warmup_start_lr = warmup_start_lr
def __call__(self):
learning_rate = lr.CosineAnnealingDecay(
learning_rate=self.learning_rate,
T_max=self.T_max,
eta_min=self.eta_min,
last_epoch=self.last_epoch)
if self.warmup_epoch > 0:
if self.warmup_steps > 0:
learning_rate = lr.LinearWarmup(
learning_rate=learning_rate,
warmup_steps=self.warmup_epoch,
start_lr=0.0,
warmup_steps=self.warmup_steps,
start_lr=self.warmup_start_lr,
end_lr=self.learning_rate,
last_epoch=self.last_epoch)
return learning_rate
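A similar sketch for the updated Cosine wrapper, showing the two new knobs, eta_min and warmup_start_lr; again the values are illustrative:

# Hypothetical usage; all values are illustrative.
sched = Cosine(learning_rate=0.1, step_each_epoch=100, epochs=10,
               eta_min=1e-5, warmup_epoch=1, warmup_start_lr=1e-4)()
for _ in range(10 * 100):
    sched.step()
print(sched.last_lr)  # the cosine curve now bottoms out near eta_min instead of 0.0
# During the first warmup epoch the learning rate ramps up from warmup_start_lr
# rather than from 0.0.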
@@ -111,6 +123,8 @@ class Step(object):
step_size (int): The interval at which the learning rate is updated.
gamma (float, optional): The ratio by which the learning rate is reduced: ``new_lr = origin_lr * gamma``.
It should be less than 1.0. Default: 0.1.
warmup_epoch(int): The number of warmup epochs for LinearWarmup. Default: 0.
warmup_start_lr(float): The initial learning rate of the warmup phase. Default: 0.0.
last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
"""
@@ -120,6 +134,7 @@ class Step(object):
step_each_epoch,
gamma,
warmup_epoch=0,
warmup_start_lr=0.0,
last_epoch=-1,
**kwargs):
super(Step, self).__init__()
@@ -127,7 +142,8 @@ class Step(object):
self.learning_rate = learning_rate
self.gamma = gamma
self.last_epoch = last_epoch
self.warmup_epoch = round(warmup_epoch * step_each_epoch)
self.warmup_steps = round(warmup_epoch * step_each_epoch)
self.warmup_start_lr = warmup_start_lr
def __call__(self):
learning_rate = lr.StepDecay(
@@ -135,11 +151,11 @@ class Step(object):
step_size=self.step_size,
gamma=self.gamma,
last_epoch=self.last_epoch)
if self.warmup_epoch > 0:
if self.warmup_steps > 0:
learning_rate = lr.LinearWarmup(
learning_rate=learning_rate,
warmup_steps=self.warmup_epoch,
start_lr=0.0,
warmup_steps=self.warmup_steps,
start_lr=self.warmup_start_lr,
end_lr=self.learning_rate,
last_epoch=self.last_epoch)
return learning_rate
@@ -152,6 +168,8 @@ class Piecewise(object):
boundaries(list): A list of step numbers. The type of element in the list is python int.
values(list): A list of learning rate values that will be picked during different epoch boundaries.
The type of element in the list is python float.
warmup_epoch(int): The number of warmup epochs for LinearWarmup. Default: 0.
warmup_start_lr(float): The initial learning rate of the warmup phase. Default: 0.0.
last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
"""
@@ -160,24 +178,26 @@ class Piecewise(object):
decay_epochs,
values,
warmup_epoch=0,
warmup_start_lr=0.0,
last_epoch=-1,
**kwargs):
super(Piecewise, self).__init__()
self.boundaries = [step_each_epoch * e for e in decay_epochs]
self.values = values
self.last_epoch = last_epoch
self.warmup_epoch = round(warmup_epoch * step_each_epoch)
self.warmup_steps = round(warmup_epoch * step_each_epoch)
self.warmup_start_lr = warmup_start_lr
def __call__(self):
learning_rate = lr.PiecewiseDecay(
boundaries=self.boundaries,
values=self.values,
last_epoch=self.last_epoch)
if self.warmup_epoch > 0:
if self.warmup_steps > 0:
learning_rate = lr.LinearWarmup(
learning_rate=learning_rate,
warmup_steps=self.warmup_epoch,
start_lr=0.0,
warmup_steps=self.warmup_steps,
start_lr=self.warmup_start_lr,
end_lr=self.values[0],
last_epoch=self.last_epoch)
return learning_rate
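Piecewise differs from the other wrappers in one detail: warmup ramps from warmup_start_lr to the first piecewise value, values[0], rather than to a single base learning rate. A hypothetical sketch, with made-up boundaries and values:

# Hypothetical usage; all values are illustrative.
sched = Piecewise(step_each_epoch=100, decay_epochs=[3, 6],
                  values=[0.1, 0.01, 0.001],
                  warmup_epoch=1, warmup_start_lr=0.0)()
sched.step()          # still inside the warmup phase
print(sched.last_lr)  # climbing from warmup_start_lr (0.0) toward values[0] (0.1)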
@@ -186,7 +206,7 @@ class Piecewise(object):
class MultiStepDecay(LRScheduler):
"""
Update the learning rate by ``gamma`` once ``epoch`` reaches one of the milestones.
The algorithm can be described as the code below.
.. code-block:: text
learning_rate = 0.5
milestones = [30, 50]
@@ -200,15 +220,15 @@ class MultiStepDecay(LRScheduler):
Args:
learning_rate (float): The initial learning rate. It is a python float number.
milestones (tuple|list): List or tuple of each boundaries. Must be increasing.
gamma (float, optional): The ratio by which the learning rate is reduced: ``new_lr = origin_lr * gamma``.
It should be less than 1.0. Default: 0.1.
last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .
Returns:
``MultiStepDecay`` instance to schedule learning rate.
Examples:
.. code-block:: python
import paddle
import numpy as np
......