Unverified commit f91bc7ba, authored by T Tingquan Gao, committed by GitHub

perf: add parameter validation (#1249)

When using warm up, the total epoch num must be greater than the warm up epoch num. Otherwise,
a warning will be raised and the warm up epoch num will be set to the total epoch num.
Parent 283f9fa9
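In short, each scheduler class now clamps warmup_epoch to the total epoch count and logs a warning. A standalone sketch of the check (using Python's built-in logging in place of ppcls.utils.logger; validate_warmup is an illustrative helper, not part of the commit):

import logging

logger = logging.getLogger(__name__)


def validate_warmup(epochs, warmup_epoch):
    # Mirrors the guard added in this commit: warm up may not cover the
    # whole run, so warmup_epoch is clamped to the total number of epochs.
    if warmup_epoch >= epochs:
        logger.warning(
            'When using warm up, the value of "Global.epochs" must be '
            'greater than value of "Optimizer.lr.warmup_epoch". The value '
            'of "Optimizer.lr.warmup_epoch" has been set to %d.', epochs)
        warmup_epoch = epochs
    return warmup_epoch


print(validate_warmup(epochs=120, warmup_epoch=5))    # -> 5, unchanged
print(validate_warmup(epochs=120, warmup_epoch=200))  # -> 120, warning logged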
@@ -11,12 +11,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from __future__ import (absolute_import, division, print_function,
                         unicode_literals)
 from paddle.optimizer import lr
 from paddle.optimizer.lr import LRScheduler
+from ppcls.utils import logger
 class Linear(object):
     """
@@ -41,7 +44,11 @@ class Linear(object):
                  warmup_start_lr=0.0,
                  last_epoch=-1,
                  **kwargs):
-        super(Linear, self).__init__()
+        super().__init__()
+        if warmup_epoch >= epochs:
+            msg = f"When using warm up, the value of \"Global.epochs\" must be greater than value of \"Optimizer.lr.warmup_epoch\". The value of \"Optimizer.lr.warmup_epoch\" has been set to {epochs}."
+            logger.warning(msg)
+            warmup_epoch = epochs
         self.learning_rate = learning_rate
         self.steps = (epochs - warmup_epoch) * step_each_epoch
         self.end_lr = end_lr
@@ -56,7 +63,8 @@ class Linear(object):
             decay_steps=self.steps,
             end_lr=self.end_lr,
             power=self.power,
-            last_epoch=self.last_epoch)
+            last_epoch=self.
+            last_epoch) if self.steps > 0 else self.learning_rate
         if self.warmup_steps > 0:
             learning_rate = lr.LinearWarmup(
                 learning_rate=learning_rate,
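The conditional added above matters in the clamped case: when warm up spans the whole run, self.steps is 0 and a PolynomialDecay over zero steps would be meaningless, so the bare float learning rate is used and handed to the warm-up wrapper instead. A rough illustration with Paddle's lr module (values are made up, not taken from the commit):

from paddle.optimizer import lr

base_lr = 0.1
warmup_steps = 500
steps = 0  # e.g. warmup_epoch == epochs, so no post-warmup steps remain

# Build the decayed schedule only if steps remain after warm up; otherwise
# fall back to the plain float learning rate.
decayed = lr.PolynomialDecay(
    learning_rate=base_lr, decay_steps=steps,
    end_lr=0.0) if steps > 0 else base_lr

# lr.LinearWarmup accepts either a float or an LRScheduler, so both
# branches can be wrapped the same way.
sched = lr.LinearWarmup(
    learning_rate=decayed,
    warmup_steps=warmup_steps,
    start_lr=0.0,
    end_lr=base_lr)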
@@ -90,7 +98,11 @@ class Cosine(object):
                  warmup_start_lr=0.0,
                  last_epoch=-1,
                  **kwargs):
-        super(Cosine, self).__init__()
+        super().__init__()
+        if warmup_epoch >= epochs:
+            msg = f"When using warm up, the value of \"Global.epochs\" must be greater than value of \"Optimizer.lr.warmup_epoch\". The value of \"Optimizer.lr.warmup_epoch\" has been set to {epochs}."
+            logger.warning(msg)
+            warmup_epoch = epochs
         self.learning_rate = learning_rate
         self.T_max = (epochs - warmup_epoch) * step_each_epoch
         self.eta_min = eta_min
@@ -103,7 +115,8 @@ class Cosine(object):
             learning_rate=self.learning_rate,
             T_max=self.T_max,
             eta_min=self.eta_min,
-            last_epoch=self.last_epoch)
+            last_epoch=self.
+            last_epoch) if self.T_max > 0 else self.learning_rate
         if self.warmup_steps > 0:
             learning_rate = lr.LinearWarmup(
                 learning_rate=learning_rate,
@@ -132,12 +145,17 @@ class Step(object):
                  learning_rate,
                  step_size,
                  step_each_epoch,
+                 epochs,
                  gamma,
                  warmup_epoch=0,
                  warmup_start_lr=0.0,
                  last_epoch=-1,
                  **kwargs):
-        super(Step, self).__init__()
+        super().__init__()
+        if warmup_epoch >= epochs:
+            msg = f"When using warm up, the value of \"Global.epochs\" must be greater than value of \"Optimizer.lr.warmup_epoch\". The value of \"Optimizer.lr.warmup_epoch\" has been set to {epochs}."
+            logger.warning(msg)
+            warmup_epoch = epochs
         self.step_size = step_each_epoch * step_size
         self.learning_rate = learning_rate
         self.gamma = gamma
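For Step (and Piecewise below), epochs becomes a new constructor argument, so it now has to be passed through from the optimizer config. A hypothetical instantiation, assuming the class is imported from ppcls/optimizer/learning_rate.py as in this diff (all values illustrative):

from ppcls.optimizer.learning_rate import Step

step_lr = Step(
    learning_rate=0.1,
    step_size=30,         # decay every 30 epochs
    step_each_epoch=500,  # iterations per epoch
    epochs=120,           # required after this commit for the warm-up check
    gamma=0.1,
    warmup_epoch=5)
scheduler = step_lr()     # returns the paddle scheduler, with warm up applied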
@@ -177,11 +195,16 @@ class Piecewise(object):
                  step_each_epoch,
                  decay_epochs,
                  values,
+                 epochs,
                  warmup_epoch=0,
                  warmup_start_lr=0.0,
                  last_epoch=-1,
                  **kwargs):
-        super(Piecewise, self).__init__()
+        super().__init__()
+        if warmup_epoch >= epochs:
+            msg = f"When using warm up, the value of \"Global.epochs\" must be greater than value of \"Optimizer.lr.warmup_epoch\". The value of \"Optimizer.lr.warmup_epoch\" has been set to {epochs}."
+            logger.warning(msg)
+            warmup_epoch = epochs
         self.boundaries = [step_each_epoch * e for e in decay_epochs]
         self.values = values
         self.last_epoch = last_epoch
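Piecewise converts its epoch boundaries into iteration boundaries before handing them to lr.PiecewiseDecay, which expects one more value than boundaries. A small sketch of that mapping (numbers are illustrative):

from paddle.optimizer import lr

step_each_epoch = 500
decay_epochs = [30, 60, 90]
values = [0.1, 0.01, 0.001, 0.0001]  # one more entry than decay_epochs

# Same conversion as in Piecewise.__init__: epoch index -> global iteration.
boundaries = [step_each_epoch * e for e in decay_epochs]  # [15000, 30000, 45000]
sched = lr.PiecewiseDecay(boundaries=boundaries, values=values)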
@@ -294,8 +317,7 @@ class MultiStepDecay(LRScheduler):
             raise ValueError('gamma should be < 1.0.')
         self.milestones = [x * step_each_epoch for x in milestones]
         self.gamma = gamma
-        super(MultiStepDecay, self).__init__(learning_rate, last_epoch,
-                                             verbose)
+        super().__init__(learning_rate, last_epoch, verbose)

     def get_lr(self):
         for i in range(len(self.milestones)):
...
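The truncated get_lr follows the usual multi-step rule: the base learning rate is multiplied by gamma once for each milestone already passed. A standalone sketch of that rule (multistep_lr is illustrative, not the class method itself):

def multistep_lr(base_lr, gamma, milestones, last_step):
    # Multiply base_lr by gamma once per milestone that has been passed.
    for i, boundary in enumerate(sorted(milestones)):
        if last_step < boundary:
            return base_lr * gamma**i
    return base_lr * gamma**len(milestones)


# With milestones at iterations 15000 and 30000 and gamma = 0.1:
assert multistep_lr(0.1, 0.1, [15000, 30000], 100) == 0.1
assert multistep_lr(0.1, 0.1, [15000, 30000], 20000) == 0.1 * 0.1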